use alloc::vec::Vec;
/// Width, in bits, of one Huffman bit accumulator (a `u64`).
pub(crate) const HUF_BITS_IN_CONTAINER: usize = u64::BITS as usize;
/// Upper bound on the bit count carried by a packed Huffman table entry
/// (checked by debug assertions when entries are packed or appended).
pub(crate) const HUF_TABLELOG_ABSOLUTEMAX: usize = 12;
/// Packs a Huffman code into the single-`u64` table-entry layout used by the
/// encoder: the code's `nb_bits` bits occupy the top of the word and the bit
/// count itself sits in the low byte.
///
/// A zero-length code packs to `0`. In debug builds this asserts that
/// `nb_bits <= HUF_TABLELOG_ABSOLUTEMAX` and that `value` fits in `nb_bits`.
#[inline(always)]
pub(crate) fn pack_huf_celt(value: u32, nb_bits: u8) -> u64 {
    debug_assert!(usize::from(nb_bits) <= HUF_TABLELOG_ABSOLUTEMAX);
    match u64::from(nb_bits) {
        0 => 0,
        width => {
            let code = u64::from(value);
            debug_assert!(code >> width == 0, "value must fit in nb_bits");
            // Left-align the code so the encoder can OR it into the top of a container.
            let shift = HUF_BITS_IN_CONTAINER as u64 - width;
            (code << shift) | width
        }
    }
}
/// Bit-level writer that appends a Huffman-coded stream to a borrowed `Vec<u8>`.
///
/// Bits are accumulated at the top of 64-bit containers and periodically
/// stored to `output` as little-endian words.
pub(crate) struct HufCStream<'a> {
    /// Destination buffer; this stream writes past the length it had at construction.
    output: &'a mut Vec<u8>,
    /// Length of `output` when the stream was created (first byte owned by the stream).
    start_idx: usize,
    /// Index in `output` where the next flush stores its bytes.
    cursor: usize,
    /// Highest index at which a full 8-byte store still fits in the reserved capacity.
    end_ptr: usize,
    /// Set when a checked flush had to clamp `cursor` back to `end_ptr`.
    overflow: bool,
    /// Bit accumulators; index 0 is flushed, index 1 is a scratch container that
    /// the unrolled encoder builds independently and merges into index 0.
    bit_container: [u64; 2],
    /// Pending-bit counters for each container; only the low 8 bits are meaningful.
    bit_pos: [u64; 2],
}
/// Reports whether the BMI2-specialised encode loop may be used.
///
/// With the `std` feature the CPU is probed at runtime; without it, only a
/// compile-time `bmi2` target feature enables the fast path.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline]
fn huf_encode_use_bmi2() -> bool {
    #[cfg(feature = "std")]
    {
        std::arch::is_x86_feature_detected!("bmi2")
    }
    #[cfg(not(feature = "std"))]
    {
        cfg!(target_feature = "bmi2")
    }
}
// Appends one packed table entry `$elt` to the container/position pair `$bc`/`$bp`.
//
// The container is shifted right by the entry's bit count (its low byte) and the
// entry's code is OR-ed into the freed top bits. In fast mode the entry is used
// unmasked for both the OR and the position update; otherwise the low byte is
// masked off the OR and only the bit count is added to the position.
macro_rules! huf_add {
    ($bc:ident, $bp:ident, $elt:expr, $fast:expr) => {{
        let cell = $elt;
        let width = cell & 0xFF;
        let (payload, advance) = if $fast {
            (cell, cell)
        } else {
            (cell & !0xFFu64, width)
        };
        $bc = ($bc >> width) | payload;
        $bp = $bp.wrapping_add(advance);
    }};
}
// Flushes the complete bytes held at the top of bit container `$bc` to the output.
//
// The top `$bp & 0xFF` bits of `$bc` are pending (newest bits highest, since
// `huf_add!` shifts right before OR-ing at the top). They are shifted down to bit 0
// and stored as a little-endian `u64` at `$out_base + $cursor`. `$cursor` then
// advances by the number of complete bytes, and `$bp` keeps only the leftover
// (< 8) bits, which remain at the top of `$bc`.
//
// The store always writes 8 bytes, so `$cursor <= $end_ptr` must hold on entry
// (`HufCStream::new` reserves 8 bytes of slack past `end_ptr`). When `$fast` is false,
// a cursor that runs past `$end_ptr` is clamped back to it and `$overflow` is set.
// When `$fast` is true, the caller guarantees the buffer is large enough.
macro_rules! huf_flush {
    ($bc:ident, $bp:ident, $cursor:ident, $overflow:ident, $out_base:ident, $end_ptr:ident, $fast:expr) => {{
        let nb_bits = ($bp & 0xFF) as usize;
        let nb_bytes = nb_bits >> 3;
        // Shifting a u64 by 64 would overflow, so an empty container flushes as zero.
        let chunk = if nb_bits == 0 {
            0
        } else {
            $bc >> (HUF_BITS_IN_CONTAINER - nb_bits)
        };
        $bp &= 7;
        debug_assert!(
            $cursor <= $end_ptr,
            "huf_flush: write cursor passed end_ptr; the fast-flush capacity contract was violated"
        );
        let bytes = chunk.to_le_bytes();
        // SAFETY: `$out_base` points at the start of a buffer with capacity of at least
        // `$end_ptr + 8` bytes, and `$cursor <= $end_ptr` (checked above in debug builds;
        // maintained by the clamp below in checked mode and by the caller's capacity
        // guarantee in fast mode), so the 8-byte store stays inside the allocation.
        unsafe {
            core::ptr::copy_nonoverlapping(bytes.as_ptr(), $out_base.add($cursor), 8);
        }
        $cursor += nb_bytes;
        if !$fast && $cursor > $end_ptr {
            $cursor = $end_ptr;
            $overflow = true;
        }
    }};
}
// Shared body of `HufCStream::encode_unrolled_scalar` and `encode_unrolled_bmi2`
// (expanded in both so the BMI2 copy is compiled with that target feature enabled).
//
// Encodes every byte of `$data` through `$table` (indexed by byte value), walking
// from the last byte towards the first, in groups of `$ku` symbols:
// - `$kff` is the `fast` flag of every `huf_flush!` (true skips the end-of-buffer clamp);
// - `$klf` is the `fast` flag of the last `huf_add!` in each group (the others are fast).
// The stream's containers, bit positions, cursor and overflow flag are copied into
// locals for the duration of the loop and written back to `$self` at the end.
macro_rules! encode_unrolled_body {
($self:expr, $table:expr, $data:expr, $ku:expr, $kff:expr, $klf:expr) => {{
let mut bc0 = $self.bit_container[0];
let mut bc1 = $self.bit_container[1];
let mut bp0 = $self.bit_pos[0];
let mut bp1 = $self.bit_pos[1];
let mut cursor = $self.cursor;
let mut overflow = $self.overflow;
let end_ptr = $self.end_ptr;
let out_base = $self.output.as_mut_ptr();
let mut n = $data.len();
// Phase 1: encode the trailing `n % $ku` symbols with checked (non-fast) adds and
// flush, so the number of symbols left is a multiple of `$ku`.
let rem = n % $ku;
if rem > 0 {
for _ in 0..rem {
n -= 1;
huf_add!(bc0, bp0, $table[$data[n] as usize], false);
}
huf_flush!(bc0, bp0, cursor, overflow, out_base, end_ptr, $kff);
}
debug_assert!(n.is_multiple_of($ku));
// Phase 2: if an odd number of `$ku`-groups remains, encode one group into
// container 0 so the main loop below always consumes pairs of groups.
if !n.is_multiple_of(2 * $ku) {
for u in 1..$ku {
huf_add!(bc0, bp0, $table[$data[n - u] as usize], true);
}
huf_add!(bc0, bp0, $table[$data[n - $ku] as usize], $klf);
huf_flush!(bc0, bp0, cursor, overflow, out_base, end_ptr, $kff);
n -= $ku;
}
debug_assert!(n.is_multiple_of(2 * $ku));
// Phase 3: two groups per iteration. The first group goes into container 0 and is
// flushed; the second is built in container 1, which is reset to empty first and
// therefore does not read container 0 while it is being filled.
while n > 0 {
for u in 1..$ku {
huf_add!(bc0, bp0, $table[$data[n - u] as usize], true);
}
huf_add!(bc0, bp0, $table[$data[n - $ku] as usize], $klf);
huf_flush!(bc0, bp0, cursor, overflow, out_base, end_ptr, $kff);
bc1 = 0;
bp1 = 0;
for u in 1..$ku {
huf_add!(bc1, bp1, $table[$data[n - $ku - u] as usize], true);
}
huf_add!(bc1, bp1, $table[$data[n - $ku - $ku] as usize], $klf);
// Merge container 1 into container 0: container 1's bits sit at its top, so shift
// container 0 down by their count (low byte of `bp1`), OR them in, and add the counts.
let nb_bits_1 = bp1 & 0xFF;
bc0 >>= nb_bits_1;
bc0 |= bc1;
bp0 = bp0.wrapping_add(bp1);
huf_flush!(bc0, bp0, cursor, overflow, out_base, end_ptr, $kff);
n -= 2 * $ku;
}
debug_assert_eq!(n, 0);
// Write the local copies of the stream state back.
$self.bit_container[0] = bc0;
$self.bit_container[1] = bc1;
$self.bit_pos[0] = bp0;
$self.bit_pos[1] = bp1;
$self.cursor = cursor;
$self.overflow = overflow;
}};
}
impl<'a> HufCStream<'a> {
    /// Creates a stream that appends Huffman-coded bits to `output`.
    ///
    /// At most `dst_capacity` bytes past `output.len()` are used. Every flush
    /// stores a whole `u64`, so the last 8 bytes of that budget are kept as
    /// slack: `end_ptr = output.len() + dst_capacity - 8`.
    ///
    /// Returns `None` when `dst_capacity <= 8`, i.e. when only the slack would fit.
    pub(crate) fn new(output: &'a mut Vec<u8>, dst_capacity: usize) -> Option<Self> {
        if dst_capacity <= 8 {
            return None;
        }
        let start_idx = output.len();
        // Guarantees `capacity >= start_idx + dst_capacity`, so an 8-byte store at
        // any index `<= end_ptr` stays inside the allocation.
        output.reserve(dst_capacity);
        Some(Self {
            bit_container: [0, 0],
            bit_pos: [0, 0],
            output,
            start_idx,
            cursor: start_idx,
            end_ptr: start_idx + dst_capacity - 8,
            overflow: false,
        })
    }

    /// Appends one packed table entry (as produced by `pack_huf_celt`) to
    /// container `idx` (0 or 1).
    ///
    /// The container is shifted right by the entry's bit count and the code is
    /// OR-ed into the freed top bits. With `FAST`, the entry is used unmasked:
    /// its low byte (the bit count) is also OR-ed into the container's low bits,
    /// and the whole entry is added to `bit_pos`. Flushes only read the top
    /// `bit_pos & 0xFF` bits of the container and the low 8 bits of `bit_pos`.
    #[inline(always)]
    pub(crate) fn add_bits<const FAST: bool>(&mut self, elt: u64, idx: usize) {
        debug_assert!(idx <= 1);
        let nb_bits = elt & 0xFF;
        debug_assert!((nb_bits as usize) <= HUF_TABLELOG_ABSOLUTEMAX);
        self.bit_container[idx] >>= nb_bits;
        let value = if FAST { elt } else { elt & !0xFFu64 };
        self.bit_container[idx] |= value;
        let nb_add = if FAST { elt } else { nb_bits };
        self.bit_pos[idx] = self.bit_pos[idx].wrapping_add(nb_add);
    }

    /// Writes the complete bytes pending in container 0 to the output.
    ///
    /// Always stores 8 bytes at `cursor`, then advances `cursor` by the number of
    /// complete bytes; the leftover (< 8) bits stay pending. Without `FAST`, a
    /// cursor that passes `end_ptr` is clamped and `overflow` is set. With `FAST`,
    /// the caller must guarantee enough capacity that `cursor` never passes
    /// `end_ptr`; debug builds assert this before every store.
    #[inline(always)]
    pub(crate) fn flush_bits<const FAST: bool>(&mut self) {
        let nb_bits = (self.bit_pos[0] & 0xFF) as usize;
        let nb_bytes = nb_bits >> 3;
        // Shifting a u64 by 64 would overflow, so an empty container flushes as zero.
        let bit_container = if nb_bits == 0 {
            0
        } else {
            self.bit_container[0] >> (HUF_BITS_IN_CONTAINER - nb_bits)
        };
        self.bit_pos[0] &= 7;
        debug_assert!(
            self.cursor <= self.end_ptr,
            "HufCStream: write cursor passed end_ptr; the fast-flush capacity contract was violated"
        );
        let bytes = bit_container.to_le_bytes();
        // SAFETY: `new` reserved capacity for `start_idx + dst_capacity` bytes and set
        // `end_ptr = start_idx + dst_capacity - 8`. Checked flushes clamp `cursor` to
        // `end_ptr`, and fast flushes require the caller to keep `cursor <= end_ptr`,
        // so `cursor + 8 <= capacity` and the store stays inside the allocation.
        unsafe {
            let dst = self.output.as_mut_ptr().add(self.cursor);
            core::ptr::copy_nonoverlapping(bytes.as_ptr(), dst, bytes.len());
        }
        self.cursor += nb_bytes;
        if !FAST && self.cursor > self.end_ptr {
            self.cursor = self.end_ptr;
            self.overflow = true;
        }
    }

    /// Huffman-encodes `data` (last byte first) using `table`, indexed by byte value.
    ///
    /// `K_UNROLL` symbols are added between flushes. `K_FAST_FLUSH` skips the
    /// end-of-buffer clamp on every flush, and `K_LAST_FAST` uses the unmasked add
    /// for the last symbol of each group. On x86 this dispatches to a BMI2-compiled
    /// copy of the loop when `huf_encode_use_bmi2` allows it.
    ///
    /// # Panics
    ///
    /// Panics if `K_UNROLL` is zero, or if a byte of `data` indexes past `table`.
    #[inline]
    pub(crate) fn encode_unrolled<
        const K_UNROLL: usize,
        const K_FAST_FLUSH: bool,
        const K_LAST_FAST: bool,
    >(
        &mut self,
        table: &[u64],
        data: &[u8],
    ) {
        // The loop partitions `data` with `% K_UNROLL`; reject a zero unroll with a
        // clear message instead of a division-by-zero panic deep in the loop.
        assert!(K_UNROLL > 0, "K_UNROLL must be at least 1");
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if huf_encode_use_bmi2() {
                // SAFETY: `huf_encode_use_bmi2` returns true only when BMI2 is available
                // (runtime-detected with `std`, statically enabled otherwise), which is
                // the precondition added by `#[target_feature(enable = "bmi2")]`.
                unsafe {
                    self.encode_unrolled_bmi2::<K_UNROLL, K_FAST_FLUSH, K_LAST_FAST>(table, data);
                }
                return;
            }
        }
        self.encode_unrolled_scalar::<K_UNROLL, K_FAST_FLUSH, K_LAST_FAST>(table, data);
    }

    /// Portable instantiation of the unrolled encode loop.
    #[inline]
    fn encode_unrolled_scalar<
        const K_UNROLL: usize,
        const K_FAST_FLUSH: bool,
        const K_LAST_FAST: bool,
    >(
        &mut self,
        table: &[u64],
        data: &[u8],
    ) {
        encode_unrolled_body!(self, table, data, K_UNROLL, K_FAST_FLUSH, K_LAST_FAST);
    }

    /// Unrolled encode loop compiled with the BMI2 target feature enabled.
    ///
    /// # Safety
    ///
    /// The executing CPU must support BMI2.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[target_feature(enable = "bmi2")]
    unsafe fn encode_unrolled_bmi2<
        const K_UNROLL: usize,
        const K_FAST_FLUSH: bool,
        const K_LAST_FAST: bool,
    >(
        &mut self,
        table: &[u64],
        data: &[u8],
    ) {
        encode_unrolled_body!(self, table, data, K_UNROLL, K_FAST_FLUSH, K_LAST_FAST);
    }

    /// Number of bits pending in container 0 (low 8 bits of its position).
    #[inline(always)]
    pub(crate) fn pending_bits(&self) -> usize {
        (self.bit_pos[0] & 0xFF) as usize
    }

    /// Terminates the stream and commits the written bytes to `output`.
    ///
    /// Appends the end-of-stream marker (a single `1` bit), flushes, and extends
    /// `output`'s length to cover the stream, including a final partial byte.
    /// Returns the number of bytes written, or `0` if the output budget overflowed,
    /// in which case `output`'s length is left unchanged.
    pub(crate) fn close(mut self) -> usize {
        // Code `1` with length 1, packed as `pack_huf_celt(1, 1)` would produce.
        let end_mark: u64 = 1u64 | (1u64 << (HUF_BITS_IN_CONTAINER as u64 - 1));
        self.add_bits::<false>(end_mark, 0);
        self.flush_bits::<false>();
        let nb_bits = self.pending_bits();
        if self.overflow {
            return 0;
        }
        let bytes_written = (self.cursor - self.start_idx) + usize::from(nb_bits > 0);
        // SAFETY: every flush stored 8 bytes at `cursor` before advancing it by at most
        // 8, so all bytes in `[start_idx, cursor)` were written, and when bits remain
        // the byte at `cursor` lies inside the last store. Without overflow,
        // `cursor <= end_ptr`, so the new length is within the reserved capacity.
        unsafe {
            self.output.set_len(self.start_idx + bytes_written);
        }
        bytes_written
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn add_bits_single_symbol_emits_correct_byte() {
        let mut out: Vec<u8> = Vec::new();
        let mut s = HufCStream::new(&mut out, 64).expect("init ok");
        let elt = pack_huf_celt(0b1011, 4);
        s.add_bits::<false>(elt, 0);
        let n = s.close();
        assert!(n > 0);
        assert_eq!(out.len(), 1);
        assert_eq!(
            out[0], 0x1B,
            "first emitted byte must mirror upstream zstd's HUF_addBits + \
             HUF_endMark packing collapsed to a 5-bit prefix 0b11011",
        );
    }

    #[test]
    fn add_bits_overflowing_container_flushes_correctly() {
        let mut out: Vec<u8> = Vec::new();
        let mut s = HufCStream::new(&mut out, 256).expect("init ok");
        for i in 0..8 {
            let elt = pack_huf_celt(i as u32, 8);
            s.add_bits::<false>(elt, 0);
        }
        s.flush_bits::<false>();
        assert_eq!(s.cursor - s.start_idx, 8);
        assert_eq!(s.pending_bits(), 0);
        let n = s.close();
        assert!(n >= 8);
    }

    #[test]
    fn encode_unrolled_dual_container_size_is_deterministic() {
        let mut out: Vec<u8> = Vec::new();
        let mut s = HufCStream::new(&mut out, 64).expect("init ok");
        let table = [pack_huf_celt(0b1010, 4); 256];
        let data = [0u8; 16];
        s.encode_unrolled::<4, false, false>(&table, &data);
        let n = s.close();
        assert_eq!(
            n, 9,
            "16 symbols * 4 bits + 1 end-mark bit = 65 bits = 9 bytes"
        );
    }

    #[test]
    fn new_rejects_capacity_of_eight_or_less() {
        let mut out: Vec<u8> = Vec::new();
        assert!(HufCStream::new(&mut out, 0).is_none());
        assert!(HufCStream::new(&mut out, 8).is_none());
        assert!(HufCStream::new(&mut out, 9).is_some());
    }

    #[test]
    fn close_reports_overflow_when_capacity_is_exhausted() {
        let mut out: Vec<u8> = Vec::new();
        // 9 bytes of budget leave a single committed byte before the 8-byte slack.
        let mut s = HufCStream::new(&mut out, 9).expect("init ok");
        let elt = pack_huf_celt(0xABC, 12);
        s.add_bits::<false>(elt, 0);
        s.add_bits::<false>(elt, 0);
        s.flush_bits::<false>();
        assert_eq!(s.close(), 0);
        assert!(out.is_empty(), "an overflowed stream must not commit any bytes");
    }

    #[test]
    fn encode_unrolled_empty_input_emits_only_end_mark() {
        let mut out: Vec<u8> = Vec::new();
        let mut s = HufCStream::new(&mut out, 16).expect("init ok");
        let table = [pack_huf_celt(1, 1); 256];
        s.encode_unrolled::<4, false, false>(&table, &[]);
        assert_eq!(s.close(), 1);
        assert_eq!(out, [0x01]);
    }

    #[test]
    fn encode_unrolled_handles_remainder_and_odd_group() {
        let mut out: Vec<u8> = Vec::new();
        let mut s = HufCStream::new(&mut out, 64).expect("init ok");
        let table = [pack_huf_celt(0b1010, 4); 256];
        // 6 symbols with K_UNROLL = 4: a 2-symbol remainder, then one odd group of 4.
        let data = [0u8; 6];
        s.encode_unrolled::<4, false, false>(&table, &data);
        assert_eq!(s.close(), 4, "6 symbols * 4 bits + 1 end-mark bit = 25 bits = 4 bytes");
        assert_eq!(out, [0xAA, 0xAA, 0xAA, 0x01]);
    }
}