#![allow(dead_code)]
use std::sync::OnceLock;
#[cfg(target_arch = "x86_64")]
#[allow(unused_imports)]
use std::arch::x86_64::*;
/// Reports whether the running CPU supports AVX2.
///
/// On `x86_64` the answer comes from runtime feature detection; every other
/// architecture reports `false`.
pub fn has_avx2() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        std::arch::is_x86_feature_detected!("avx2")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
/// Reports whether the running CPU supports the AVX-512 subset this module checks for.
///
/// On `x86_64` both the `avx512f` and `avx512bw` features must be detected at
/// runtime; every other architecture reports `false`.
pub fn has_avx512() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        let foundation = std::arch::is_x86_feature_detected!("avx512f");
        foundation && std::arch::is_x86_feature_detected!("avx512bw")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
/// Reports whether NEON is considered available.
///
/// This is a compile-time answer: `true` when the crate is built for
/// `aarch64`, `false` for every other target architecture.
pub fn has_neon() -> bool {
    cfg!(target_arch = "aarch64")
}
/// Width, in bits, of an index into a vector lookup table.
pub const VECTOR_TABLE_BITS: usize = 8;
/// Number of entries in a vector lookup table: `1 << VECTOR_TABLE_BITS`, i.e. 256.
pub const VECTOR_TABLE_SIZE: usize = 1 << VECTOR_TABLE_BITS;
/// One packed 16-bit slot of a vector lookup table.
///
/// Bits 0..=8 hold the decoded symbol and bits 9..=12 hold the code length.
/// A code length of zero marks a slot that this table cannot resolve
/// (see [`VectorEntry::is_overflow`]); the `Default` value is such a slot.
#[derive(Clone, Copy, Default, Debug)]
#[repr(C)]
pub struct VectorEntry(u16);

impl VectorEntry {
    /// Mask selecting the 9-bit symbol field.
    const SYMBOL_MASK: u16 = 0x1FF;
    /// Position of the code-length field inside the packed value.
    const LENGTH_SHIFT: u32 = 9;
    /// Mask selecting the 4-bit code-length field once shifted down.
    const LENGTH_MASK: u16 = 0xF;

    /// Packs `symbol` and its code length `bits` into one entry.
    ///
    /// Neither argument is masked, so callers must pass values that fit
    /// their fields.
    const fn new(symbol: u16, bits: u8) -> Self {
        let length_field = (bits as u16) << Self::LENGTH_SHIFT;
        Self(length_field | symbol)
    }

    /// Returns the stored symbol (low 9 bits).
    #[inline(always)]
    pub fn symbol(self) -> u16 {
        self.0 & Self::SYMBOL_MASK
    }

    /// Returns the stored code length (4-bit field above the symbol).
    #[inline(always)]
    pub fn bits(self) -> u8 {
        let shifted = self.0 >> Self::LENGTH_SHIFT;
        (shifted & Self::LENGTH_MASK) as u8
    }

    /// Returns `true` when the entry carries no code length, i.e. the slot
    /// cannot be decoded through this table.
    #[inline(always)]
    pub fn is_overflow(self) -> bool {
        self.bits() == 0
    }

    /// Returns the packed 16-bit representation unchanged.
    #[inline(always)]
    pub fn raw(self) -> u16 {
        self.0
    }
}
/// Process-wide cache of the fixed-Huffman vector table; it is filled on first
/// use by `get_fixed_vector_table`, which initialises it with
/// `build_fixed_vector_table`.
static FIXED_VECTOR_TABLE: OnceLock<Box<[VectorEntry; VECTOR_TABLE_SIZE]>> = OnceLock::new();
/// Builds the 256-entry lookup table for the fixed Huffman code.
///
/// For every symbol in `0..288` whose code is at most 8 bits long, the entry
/// is written to each index whose low `len` bits equal the bit-reversed code;
/// the remaining high bits of the index take every possible value. Symbols
/// with longer codes are skipped, so the slots they would need keep the
/// default (zero-length) entry.
fn build_fixed_vector_table() -> Box<[VectorEntry; VECTOR_TABLE_SIZE]> {
    let mut table = Box::new([VectorEntry::default(); VECTOR_TABLE_SIZE]);
    for sym in 0u16..288 {
        let (code, len) = fixed_huffman_code(sym);
        if len > 8 {
            // Too long to resolve with an 8-bit index.
            continue;
        }
        let width = usize::from(len);
        let low_bits = usize::from(reverse_bits(code, len));
        let entry = VectorEntry::new(sym, len);
        // Replicate the entry over every combination of the unused high bits.
        for high_bits in 0..(1usize << (8 - width)) {
            if let Some(slot) = table.get_mut(low_bits | (high_bits << width)) {
                *slot = entry;
            }
        }
    }
    table
}
/// Returns the fixed Huffman `(code, code length)` pair for literal/length
/// symbol `sym`.
///
/// The code is returned in its natural (most-significant-bit-first) form.
/// Symbols outside `0..=287` yield `(0, 0)`.
fn fixed_huffman_code(sym: u16) -> (u16, u8) {
    // Each arm: (first symbol of the range, code of that first symbol, length).
    let (first_sym, first_code, len): (u16, u16, u8) = match sym {
        0..=143 => (0, 0b0011_0000, 8),
        144..=255 => (144, 0b1_1001_0000, 9),
        256..=279 => (256, 0b000_0000, 7),
        280..=287 => (280, 0b1100_0000, 8),
        _ => return (0, 0),
    };
    // Codes within a range are consecutive.
    (first_code + (sym - first_sym), len)
}
/// Reverses the order of the low `n` bits of `code`.
///
/// Bits are taken from the least-significant end of `code` and appended to
/// the result one at a time, so the result holds the low `n` bits mirrored
/// (for `n <= 16`); `n == 0` yields `0`.
#[inline(always)]
fn reverse_bits(code: u16, n: u8) -> u16 {
    let (reversed, _) = (0..n).fold((0u16, code), |(acc, pending), _| {
        ((acc << 1) | (pending & 1), pending >> 1)
    });
    reversed
}
/// Heap-allocated vector lookup table with `VECTOR_TABLE_SIZE` entries.
#[derive(Clone)]
pub struct VectorTable {
    /// The entries; slot `i` is filled from `litlen.lookup(i)` by
    /// [`VectorTable::build_from_litlen`].
    pub table: Box<[VectorEntry; VECTOR_TABLE_SIZE]>,
}

impl VectorTable {
    /// Creates a table in which every slot holds `VectorEntry::default()`.
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        let table = Box::new([VectorEntry::default(); VECTOR_TABLE_SIZE]);
        Self { table }
    }

    /// Rebuilds every slot from a literal/length table.
    ///
    /// Slot `i` receives the literal value and codeword length reported by
    /// `litlen.lookup(i)` when that entry is a literal whose codeword is at
    /// most 8 bits long; every other slot is reset to the zero-length entry.
    pub fn build_from_litlen(
        &mut self,
        litlen: &crate::decompress::inflate::libdeflate_entry::LitLenTable,
    ) {
        for (index, slot) in self.table.iter_mut().enumerate() {
            let entry = litlen.lookup(index as u64);
            let resolvable = entry.is_literal() && entry.codeword_bits() <= 8;
            *slot = if resolvable {
                VectorEntry::new(entry.literal_value() as u16, entry.codeword_bits())
            } else {
                VectorEntry::new(0, 0)
            };
        }
    }
}
/// Returns the shared fixed-Huffman vector table, building it on first call.
pub fn get_fixed_vector_table() -> &'static [VectorEntry; VECTOR_TABLE_SIZE] {
    let boxed = FIXED_VECTOR_TABLE.get_or_init(build_fixed_vector_table);
    // Peel the reference and the `Box` to borrow the array itself.
    &**boxed
}
#[cfg(target_arch = "x86_64")]
mod avx2_impl {
    use super::*;

    /// Looks up one table entry per lane from the low 8 bits of each bit buffer.
    ///
    /// Returns `(symbols, bits, any_overflow)`: the decoded symbol narrowed to
    /// `u8`, the code length of each lane, and a flag telling the caller that
    /// at least one lane cannot be handled here. The flag is set when a lane's
    /// entry has a zero code length, or when its symbol is not a literal
    /// (value above 255) and therefore cannot be represented in the `u8`
    /// symbol array.
    ///
    /// # Safety
    ///
    /// The function is compiled with the `avx2` target feature enabled, so the
    /// caller must ensure the running CPU supports AVX2.
    #[target_feature(enable = "avx2")]
    pub unsafe fn decode_8_symbols(
        bit_buffers: &[u64; 8],
        table: &[VectorEntry; VECTOR_TABLE_SIZE],
    ) -> ([u8; 8], [u8; 8], bool) {
        let entries: [VectorEntry; 8] =
            std::array::from_fn(|lane| table[(bit_buffers[lane] & 0xFF) as usize]);
        let symbols = entries.map(|entry| entry.symbol() as u8);
        let bits = entries.map(|entry| entry.bits());
        // Non-literal symbols would otherwise be truncated by the `as u8`
        // above and look like ordinary literals to the caller.
        let any_overflow = entries
            .iter()
            .any(|entry| entry.is_overflow() || entry.symbol() > 0xFF);
        (symbols, bits, any_overflow)
    }

    /// Drops `bits_consumed[i]` bits from the low end of `bit_buffers[i]`.
    ///
    /// # Safety
    ///
    /// The function is compiled with the `avx2` target feature enabled, so the
    /// caller must ensure the running CPU supports AVX2.
    #[target_feature(enable = "avx2")]
    pub unsafe fn advance_bit_buffers(bit_buffers: &mut [u64; 8], bits_consumed: &[u8; 8]) {
        for (buffer, &shift) in bit_buffers.iter_mut().zip(bits_consumed) {
            *buffer >>= shift;
        }
    }

    /// Tops up every lane's bit buffer with whole bytes from `input`.
    ///
    /// For each lane, bytes are read from `input[positions[i]..]` and inserted
    /// at bit offset `bits_left[i]` while that count is at most 56 and input
    /// remains; `positions[i]` and `bits_left[i]` are updated accordingly.
    /// `bits_left[i]` must therefore describe the number of valid bits
    /// currently held in `bit_buffers[i]`.
    ///
    /// # Safety
    ///
    /// The function is compiled with the `avx2` target feature enabled, so the
    /// caller must ensure the running CPU supports AVX2.
    #[target_feature(enable = "avx2")]
    pub unsafe fn refill_bit_buffers(
        bit_buffers: &mut [u64; 8],
        bits_left: &mut [u32; 8],
        input: &[u8],
        positions: &mut [usize; 8],
    ) {
        let lanes = bit_buffers
            .iter_mut()
            .zip(bits_left.iter_mut())
            .zip(positions.iter_mut());
        for ((buffer, left), position) in lanes {
            while *left <= 56 {
                let Some(&byte) = input.get(*position) else {
                    break;
                };
                *buffer |= u64::from(byte) << *left;
                *position += 1;
                *left += 8;
            }
        }
    }
}
#[cfg(target_arch = "aarch64")]
mod neon_impl {
    use super::*;
    use std::arch::aarch64::*;

    /// Looks up one table entry per lane from the low 8 bits of each bit buffer.
    ///
    /// Returns `(symbols, bits, any_overflow)`: the decoded symbol narrowed to
    /// `u8`, the code length of each lane, and a flag telling the caller that
    /// at least one lane cannot be handled here. The flag is set when a lane's
    /// code length is zero (checked with a NEON byte compare), or when its
    /// symbol is not a literal (value above 255) and therefore cannot be
    /// represented in the `u8` symbol array.
    ///
    /// # Safety
    ///
    /// Uses NEON intrinsics; the caller must ensure NEON is available on the
    /// running CPU.
    pub unsafe fn decode_8_symbols(
        bit_buffers: &[u64; 8],
        table: &[VectorEntry; VECTOR_TABLE_SIZE],
    ) -> ([u8; 8], [u8; 8], bool) {
        let entries: [VectorEntry; 8] =
            std::array::from_fn(|lane| table[(bit_buffers[lane] & 0xFF) as usize]);
        let symbols = entries.map(|entry| entry.symbol() as u8);
        let bits = entries.map(|entry| entry.bits());
        // Compare all eight code lengths against zero at once; any matching
        // byte makes the 64-bit view of the comparison result non-zero.
        let bits_vec = vld1_u8(bits.as_ptr());
        let zero = vdup_n_u8(0);
        let cmp = vceq_u8(bits_vec, zero);
        let any_zero_length = vget_lane_u64(vreinterpret_u64_u8(cmp), 0) != 0;
        // Non-literal symbols would otherwise be truncated by the `as u8`
        // above and look like ordinary literals to the caller.
        let any_non_literal = entries.iter().any(|entry| entry.symbol() > 0xFF);
        (symbols, bits, any_zero_length || any_non_literal)
    }

    /// Drops `bits_consumed[i]` bits from the low end of `bit_buffers[i]`.
    ///
    /// # Safety
    ///
    /// Declared `unsafe` for symmetry with the other lane helpers; the body
    /// performs no unsafe operation.
    pub unsafe fn advance_bit_buffers(bit_buffers: &mut [u64; 8], bits_consumed: &[u8; 8]) {
        for (buffer, &shift) in bit_buffers.iter_mut().zip(bits_consumed) {
            *buffer >>= shift;
        }
    }

    /// Tops up every lane's bit buffer with whole bytes from `input`.
    ///
    /// For each lane, bytes are read from `input[positions[i]..]` and inserted
    /// at bit offset `bits_left[i]` while that count is at most 56 and input
    /// remains; `positions[i]` and `bits_left[i]` are updated accordingly.
    /// `bits_left[i]` must therefore describe the number of valid bits
    /// currently held in `bit_buffers[i]`.
    ///
    /// # Safety
    ///
    /// Declared `unsafe` for symmetry with the other lane helpers; the body
    /// performs no unsafe operation.
    pub unsafe fn refill_bit_buffers(
        bit_buffers: &mut [u64; 8],
        bits_left: &mut [u32; 8],
        input: &[u8],
        positions: &mut [usize; 8],
    ) {
        let lanes = bit_buffers
            .iter_mut()
            .zip(bits_left.iter_mut())
            .zip(positions.iter_mut());
        for ((buffer, left), position) in lanes {
            while *left <= 56 {
                let Some(&byte) = input.get(*position) else {
                    break;
                };
                *buffer |= u64::from(byte) << *left;
                *position += 1;
                *left += 8;
            }
        }
    }
}
/// Per-lane decoding state for eight lanes processed side by side.
pub struct VectorLanes {
    /// Pending bits of each lane, consumed from the low end.
    pub bit_buffers: [u64; 8],
    /// Bit count tracked for each lane's buffer.
    pub bits_left: [u32; 8],
    /// Next input byte index of each lane.
    pub input_positions: [usize; 8],
    /// Next output byte index of each lane.
    pub output_positions: [usize; 8],
    /// Per-lane flag; `true` marks a lane that is no longer active.
    pub lane_overflow: [bool; 8],
}

impl VectorLanes {
    /// Creates the state for eight evenly spaced lanes.
    ///
    /// Lane `i` starts reading at `i * (input_len / 8)` and writing at
    /// `i * (output_hint / 8)`; all buffers start empty and every lane is
    /// active.
    pub fn new(input_len: usize, output_hint: usize) -> Self {
        let input_stride = input_len / 8;
        let output_stride = output_hint / 8;
        Self {
            bit_buffers: [0; 8],
            bits_left: [0; 8],
            input_positions: std::array::from_fn(|lane| lane * input_stride),
            output_positions: std::array::from_fn(|lane| lane * output_stride),
            lane_overflow: [false; 8],
        }
    }

    /// Returns `true` when every lane has its overflow flag set.
    pub fn all_overflow(&self) -> bool {
        !self.lane_overflow.contains(&false)
    }

    /// Counts the lanes whose overflow flag is still clear.
    pub fn active_lanes(&self) -> usize {
        self.lane_overflow
            .iter()
            .map(|&overflowed| usize::from(!overflowed))
            .sum()
    }
}
/// Decodes fixed-Huffman literals on eight lanes side by side (x86_64 build).
///
/// `input` is divided into eight lanes starting at multiples of
/// `input.len() / 8`; lane `i` writes its literals into `output` starting at
/// `i * (output_size_hint / 8)`. Decoding runs for at most `output.len()`
/// rounds and stops early as soon as `decode_8_symbols` reports overflow for
/// any lane, or once every lane has run out of input bits.
///
/// Returns the total number of bytes written across all lanes.
///
/// # Errors
///
/// Returns `ErrorKind::Unsupported` when the CPU does not support AVX2.
#[cfg(target_arch = "x86_64")]
pub fn decode_fixed_vector(
    input: &[u8],
    output: &mut [u8],
    output_size_hint: usize,
) -> std::io::Result<usize> {
    if !has_avx2() {
        return Err(std::io::Error::new(
            std::io::ErrorKind::Unsupported,
            "AVX2 not available",
        ));
    }
    let table = get_fixed_vector_table();
    let mut lanes = VectorLanes::new(input.len(), output_size_hint);
    // SAFETY: `has_avx2()` returned true above, so the AVX2 target feature
    // the helper is compiled with is available.
    unsafe {
        avx2_impl::refill_bit_buffers(
            &mut lanes.bit_buffers,
            &mut lanes.bits_left,
            input,
            &mut lanes.input_positions,
        );
    }
    let mut total_decoded = 0;
    let max_iterations = output.len();
    for _ in 0..max_iterations {
        // SAFETY: AVX2 availability was verified at the top of the function.
        let (symbols, bits_consumed, any_overflow) =
            unsafe { avx2_impl::decode_8_symbols(&lanes.bit_buffers, table) };
        if any_overflow {
            break;
        }
        for (i, (&symbol, &consumed)) in symbols.iter().zip(bits_consumed.iter()).enumerate() {
            if lanes.lane_overflow[i] {
                continue;
            }
            // A code longer than the bits actually loaded was decoded from
            // zero padding past the end of the input: retire the lane instead
            // of emitting it.
            if lanes.bits_left[i] < u32::from(consumed) {
                lanes.lane_overflow[i] = true;
                continue;
            }
            let out_pos = lanes.output_positions[i];
            if out_pos < output.len() && symbol < 0xFF {
                output[out_pos] = symbol;
                lanes.output_positions[i] += 1;
                total_decoded += 1;
            }
        }
        if lanes.lane_overflow.iter().all(|&exhausted| exhausted) {
            break;
        }
        // SAFETY: AVX2 availability was verified at the top of the function.
        unsafe {
            avx2_impl::advance_bit_buffers(&mut lanes.bit_buffers, &bits_consumed);
        }
        // Account for the consumed bits BEFORE refilling: the refill inserts
        // new bytes at bit offset `bits_left`, so a stale count would place
        // them above the real end of the buffer and corrupt the stream.
        for (left, &consumed) in lanes.bits_left.iter_mut().zip(bits_consumed.iter()) {
            *left = left.saturating_sub(u32::from(consumed));
        }
        // SAFETY: AVX2 availability was verified at the top of the function.
        unsafe {
            avx2_impl::refill_bit_buffers(
                &mut lanes.bit_buffers,
                &mut lanes.bits_left,
                input,
                &mut lanes.input_positions,
            );
        }
    }
    Ok(total_decoded)
}
/// Decodes fixed-Huffman literals on eight lanes side by side (aarch64 build).
///
/// `input` is divided into eight lanes starting at multiples of
/// `input.len() / 8`; lane `i` writes its literals into `output` starting at
/// `i * (output_size_hint / 8)`. Decoding runs for at most `output.len()`
/// rounds and stops early as soon as `decode_8_symbols` reports overflow for
/// any lane, or once every lane has run out of input bits.
///
/// Returns the total number of bytes written across all lanes. This variant
/// never returns an error.
#[cfg(target_arch = "aarch64")]
pub fn decode_fixed_vector(
    input: &[u8],
    output: &mut [u8],
    output_size_hint: usize,
) -> std::io::Result<usize> {
    let table = get_fixed_vector_table();
    let mut lanes = VectorLanes::new(input.len(), output_size_hint);
    // SAFETY: the helper only performs bounds-checked slice accesses.
    unsafe {
        neon_impl::refill_bit_buffers(
            &mut lanes.bit_buffers,
            &mut lanes.bits_left,
            input,
            &mut lanes.input_positions,
        );
    }
    let mut total_decoded = 0;
    let max_iterations = output.len();
    for _ in 0..max_iterations {
        // SAFETY: the helper needs NEON, which this file treats as always
        // present on aarch64 (`has_neon` returns `true` unconditionally there).
        let (symbols, bits_consumed, any_overflow) =
            unsafe { neon_impl::decode_8_symbols(&lanes.bit_buffers, table) };
        if any_overflow {
            break;
        }
        for (i, (&symbol, &consumed)) in symbols.iter().zip(bits_consumed.iter()).enumerate() {
            if lanes.lane_overflow[i] {
                continue;
            }
            // A code longer than the bits actually loaded was decoded from
            // zero padding past the end of the input: retire the lane instead
            // of emitting it.
            if lanes.bits_left[i] < u32::from(consumed) {
                lanes.lane_overflow[i] = true;
                continue;
            }
            let out_pos = lanes.output_positions[i];
            if out_pos < output.len() && symbol < 0xFF {
                output[out_pos] = symbol;
                lanes.output_positions[i] += 1;
                total_decoded += 1;
            }
        }
        if lanes.lane_overflow.iter().all(|&exhausted| exhausted) {
            break;
        }
        // SAFETY: the helper only shifts the lane buffers.
        unsafe {
            neon_impl::advance_bit_buffers(&mut lanes.bit_buffers, &bits_consumed);
        }
        // Account for the consumed bits BEFORE refilling: the refill inserts
        // new bytes at bit offset `bits_left`, so a stale count would place
        // them above the real end of the buffer and corrupt the stream.
        for (left, &consumed) in lanes.bits_left.iter_mut().zip(bits_consumed.iter()) {
            *left = left.saturating_sub(u32::from(consumed));
        }
        // SAFETY: the helper only performs bounds-checked slice accesses.
        unsafe {
            neon_impl::refill_bit_buffers(
                &mut lanes.bit_buffers,
                &mut lanes.bits_left,
                input,
                &mut lanes.input_positions,
            );
        }
    }
    Ok(total_decoded)
}
/// Fallback for targets other than x86_64 and aarch64: vector decoding is
/// not implemented there.
///
/// # Errors
///
/// Always returns `ErrorKind::Unsupported`; the arguments are ignored.
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
pub fn decode_fixed_vector(
    _input: &[u8],
    _output: &mut [u8],
    _output_size_hint: usize,
) -> std::io::Result<usize> {
    let unsupported = std::io::Error::new(
        std::io::ErrorKind::Unsupported,
        "Vector decode only available on x86_64 or aarch64",
    );
    Err(unsupported)
}
/// Decodes up to four consecutive literals from the low bits of `bitbuf`.
///
/// Each step looks up the low 8 bits of the remaining buffer in `table` and
/// stops at the first entry that is an overflow slot or whose symbol is 256
/// or above (not a literal byte).
///
/// Returns `(symbols, count, bits_consumed)`: only `symbols[..count]` is
/// meaningful, and `bits_consumed` is the total code length of those
/// literals.
#[inline(always)]
pub fn decode_multi_literals(
    bitbuf: u64,
    table: &[VectorEntry; VECTOR_TABLE_SIZE],
) -> ([u8; 4], usize, u32) {
    let mut symbols = [0u8; 4];
    let mut window = bitbuf;
    let mut total_bits = 0u32;
    let mut count = 0usize;
    for slot in symbols.iter_mut() {
        let entry = table[(window & 0xFF) as usize];
        if entry.is_overflow() {
            break;
        }
        let symbol = entry.symbol();
        if symbol >= 256 {
            break;
        }
        let width = u32::from(entry.bits());
        *slot = symbol as u8;
        window >>= width;
        total_bits += width;
        count += 1;
    }
    (symbols, count, total_bits)
}
/// Decodes a fixed-Huffman block from `bits` into `output`, starting at
/// `out_pos`, and returns the output position reached at end-of-block.
///
/// Each round first tries the multi-literal fast path
/// (`decode_multi_literals`); when that yields nothing, a single symbol is
/// decoded through the full fixed tables, handling end-of-block, a single
/// literal, or a length/distance copy.
///
/// # Errors
///
/// * `ErrorKind::WriteZero` ("Output full") when `output` has no room for the
///   decoded literals or for a copy.
/// * `ErrorKind::InvalidData` when a copy distance is zero or reaches before
///   the start of `output`.
pub fn decode_fixed_multi_literal_bits(
    bits: &mut crate::decompress::inflate::consume_first_decode::Bits,
    output: &mut [u8],
    mut out_pos: usize,
) -> std::io::Result<usize> {
    // Shared constructor for the "no room left in `output`" error.
    let output_full = || std::io::Error::new(std::io::ErrorKind::WriteZero, "Output full");

    let literal_table = get_fixed_vector_table();
    let fixed_tables = crate::decompress::inflate::libdeflate_decode::get_fixed_tables();
    loop {
        // Fast path: several short literals from one peek.
        if bits.available() < 32 {
            bits.refill();
        }
        let (literals, literal_count, literal_bits) =
            decode_multi_literals(bits.peek(), literal_table);
        if literal_count > 0 {
            let end = out_pos + literal_count;
            if end > output.len() {
                return Err(output_full());
            }
            output[out_pos..end].copy_from_slice(&literals[..literal_count]);
            out_pos = end;
            bits.consume(literal_bits);
            continue;
        }

        // Slow path: one symbol through the full literal/length table.
        if bits.available() < 15 {
            bits.refill();
        }
        let litlen_bits = bits.peek();
        let entry = fixed_tables.0.lookup(litlen_bits);
        bits.consume_entry(entry.raw());
        if entry.is_end_of_block() {
            return Ok(out_pos);
        }
        if entry.is_literal() {
            let Some(slot) = output.get_mut(out_pos) else {
                return Err(output_full());
            };
            *slot = entry.literal_value();
            out_pos += 1;
            continue;
        }

        // Length/distance pair.
        let length = entry.decode_length(litlen_bits);
        if bits.available() < 15 {
            bits.refill();
        }
        let distance_bits = bits.peek();
        let dist_entry = fixed_tables.1.lookup(distance_bits);
        bits.consume_entry(dist_entry.raw());
        let distance = dist_entry.decode_distance(distance_bits);
        if distance == 0 || distance as usize > out_pos {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("Invalid distance {} at pos {}", distance, out_pos),
            ));
        }
        let dist = distance as usize;
        let len = length as usize;
        let copy_end = out_pos + len;
        if copy_end > output.len() {
            return Err(output_full());
        }
        // Byte-by-byte forward copy: source and destination may overlap, and
        // later bytes must see the ones written earlier in this loop.
        for dst in out_pos..copy_end {
            output[dst] = output[dst - dist];
        }
        out_pos = copy_end;
    }
}
/// Decodes a fixed-Huffman block from `input` into the start of `output`.
///
/// Wraps `input` in a fresh bit reader and delegates to
/// `decode_fixed_multi_literal_bits` with an initial output position of 0;
/// the result and errors are those of that function.
pub fn decode_fixed_multi_literal(input: &[u8], output: &mut [u8]) -> std::io::Result<usize> {
    let mut reader = crate::decompress::inflate::consume_first_decode::Bits::new(input);
    let start = 0;
    decode_fixed_multi_literal_bits(&mut reader, output, start)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The fixed table must hold exactly the symbols whose codes fit in 8 bits.
    #[test]
    fn test_vector_table_builds() {
        let table = get_fixed_vector_table();
        let valid = table.iter().filter(|e| !e.is_overflow()).count();
        let overflow = table.iter().filter(|e| e.is_overflow()).count();
        eprintln!("\nVector table statistics:");
        eprintln!("  Valid entries: {}", valid);
        eprintln!("  Overflow entries: {}", overflow);
        eprintln!("  Coverage: {:.1}%", 100.0 * valid as f64 / 256.0);
        assert!(
            valid >= 200,
            "Should have >=200 valid entries, got {}",
            valid
        );
        assert_eq!(valid + overflow, VECTOR_TABLE_SIZE);

        // 0x86 is the bit-reversed 8-bit code 0x61 of literal 49.
        let literal = table[0x86];
        assert_eq!((literal.symbol(), literal.bits()), (49, 8));
        // All-zero bits are the 7-bit end-of-block code (symbol 256).
        let end_of_block = table[0x00];
        assert_eq!((end_of_block.symbol(), end_of_block.bits()), (256, 7));
        // All-one bits are the prefix of a 9-bit code, which the table cannot resolve.
        assert!(table[0xFF].is_overflow());
    }

    /// Smoke test: the detection helpers must be callable on every target.
    #[test]
    fn test_has_simd() {
        eprintln!("\nSIMD availability:");
        eprintln!("  AVX2: {}", has_avx2());
        eprintln!("  AVX-512: {}", has_avx512());
        eprintln!("  NEON: {}", has_neon());
    }

    /// Eight identical lanes holding literal 49 must decode identically.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_decode_8_symbols() {
        if !has_avx2() {
            eprintln!("Skipping AVX2 test - not available");
            return;
        }
        let table = get_fixed_vector_table();
        let bit_buffers = [0x86u64; 8];
        // SAFETY: AVX2 availability was checked just above.
        let (symbols, bits, any_overflow) =
            unsafe { avx2_impl::decode_8_symbols(&bit_buffers, table) };
        eprintln!("\nDecode 8 symbols test:");
        eprintln!("  Symbols: {:?}", symbols);
        eprintln!("  Bits: {:?}", bits);
        eprintln!("  Any overflow: {}", any_overflow);
        assert_eq!(symbols, [49u8; 8], "every lane should decode literal 49");
        assert_eq!(bits, [8u8; 8], "every lane should consume 8 bits");
        assert!(!any_overflow, "literal lanes must not report overflow");
    }

    /// Eight identical lanes holding literal 49 must decode identically.
    #[test]
    #[cfg(target_arch = "aarch64")]
    fn test_decode_8_symbols_neon() {
        let table = get_fixed_vector_table();
        let bit_buffers = [0x86u64; 8];
        // SAFETY: NEON is treated as always available on aarch64 in this file.
        let (symbols, bits, any_overflow) =
            unsafe { super::neon_impl::decode_8_symbols(&bit_buffers, table) };
        eprintln!("\nNEON Decode 8 symbols test:");
        eprintln!("  Symbols: {:?}", symbols);
        eprintln!("  Bits: {:?}", bits);
        eprintln!("  Any overflow: {}", any_overflow);
        assert_eq!(symbols, [49u8; 8], "every lane should decode literal 49");
        assert_eq!(bits, [8u8; 8], "every lane should consume 8 bits");
        assert!(!any_overflow, "literal lanes must not report overflow");
    }

    /// Smoke test: the lane decoder must run to completion without panicking.
    #[test]
    fn bench_vector_decode() {
        let mut output = vec![0u8; 10000];
        let input = vec![0x86u8; 1000];
        let result = decode_fixed_vector(&input, &mut output, 10000);
        match result {
            Ok(decoded) => {
                eprintln!("\nVector decode test:");
                eprintln!("  Decoded {} bytes", decoded);
            }
            Err(e) => {
                eprintln!("\nVector decode not available: {}", e);
            }
        }
    }

    /// Four back-to-back 8-bit literals must all be decoded in one call.
    #[test]
    fn test_multi_literal_decode() {
        let table = get_fixed_vector_table();
        let bitbuf = 0x86_86_86_86u64;
        let (symbols, count, bits) = decode_multi_literals(bitbuf, table);
        eprintln!("\nMulti-literal decode test:");
        eprintln!("  Symbols: {:?}", &symbols[..count]);
        eprintln!("  Count: {}", count);
        eprintln!("  Bits consumed: {}", bits);
        assert!(count >= 1, "Should decode at least 1 symbol");
        assert!(bits > 0, "Should consume some bits");
        assert_eq!(count, 4);
        assert_eq!(symbols, [49u8; 4]);
        assert_eq!(bits, 32);

        // An all-zero buffer starts with end-of-block, so nothing is decoded.
        let (_, none, no_bits) = decode_multi_literals(0, table);
        assert_eq!((none, no_bits), (0, 0));
    }

    /// Throughput probe for `decode_multi_literals`; prints timings only.
    #[test]
    fn bench_multi_literal() {
        let table = get_fixed_vector_table();
        let iterations = 10_000_000u64;
        let patterns = [
            0x86_86_86_86u64,
            0x87_86_87_86u64,
            0x88_88_88_88u64,
            0x89_89_89_89u64,
        ];
        let start = std::time::Instant::now();
        let mut total_count = 0u64;
        let mut total_bits = 0u64;
        for i in 0..iterations {
            let bitbuf = patterns[(i & 3) as usize].wrapping_add(i);
            let (_, count, bits) = decode_multi_literals(bitbuf, table);
            total_count += count as u64;
            total_bits += bits as u64;
        }
        let elapsed = start.elapsed();
        let per_sec = iterations as f64 / elapsed.as_secs_f64();
        let symbols_per_sec = (total_count as f64) / elapsed.as_secs_f64();
        eprintln!("\nMulti-literal benchmark:");
        eprintln!(
            "  {} iterations in {:.2}ms",
            iterations,
            elapsed.as_secs_f64() * 1000.0
        );
        eprintln!("  {:.1} M decodes/sec", per_sec / 1_000_000.0);
        eprintln!("  {:.1} M symbols/sec", symbols_per_sec / 1_000_000.0);
        eprintln!(
            "  Avg symbols/decode: {:.2}",
            total_count as f64 / iterations as f64
        );
        eprintln!(
            "  Avg bits/decode: {:.2}",
            total_bits as f64 / iterations as f64
        );
        // Each call decodes at most four literals of at most 8 bits each.
        assert!(total_count <= 4 * iterations);
        assert!(total_bits <= 8 * total_count);
    }

    /// Throughput probe for raw table lookups; prints timings only.
    #[test]
    fn bench_vector_table_lookup() {
        let table = get_fixed_vector_table();
        let iterations = 10_000_000u64;
        let start = std::time::Instant::now();
        let mut sum = 0u64;
        for i in 0..iterations {
            let entry = table[(i & 0xFF) as usize];
            sum = sum.wrapping_add(entry.symbol() as u64);
        }
        let elapsed = start.elapsed();
        let lookups_per_sec = iterations as f64 / elapsed.as_secs_f64();
        eprintln!("\nVector table lookup benchmark:");
        eprintln!(
            "  {} lookups in {:.2}ms",
            iterations,
            elapsed.as_secs_f64() * 1000.0
        );
        eprintln!("  {:.1} M lookups/sec", lookups_per_sec / 1_000_000.0);
        eprintln!("  (sum: {} to prevent opt)", sum);
    }
}