use alloc::vec::Vec;
/// Width, in bits, of one bit-container word (a `u64`).
pub(crate) const HUF_BITS_IN_CONTAINER: usize = u64::BITS as usize;
/// Largest Huffman code length accepted by the packing and bit-adding helpers
/// (checked with debug assertions).
pub(crate) const HUF_TABLELOG_ABSOLUTEMAX: usize = 12;
/// Packs a Huffman code into the single-word layout consumed by
/// `HufCStream::add_bits`.
///
/// The code length `nb_bits` is stored in the low byte, and `value` is
/// left-aligned so that it occupies the top `nb_bits` bits of the word. A
/// zero-length code packs to `0`.
///
/// Debug builds check that `nb_bits <= HUF_TABLELOG_ABSOLUTEMAX` and that
/// `value` fits in `nb_bits` bits.
#[inline(always)]
pub(crate) fn pack_huf_celt(value: u32, nb_bits: u8) -> u64 {
    debug_assert!(usize::from(nb_bits) <= HUF_TABLELOG_ABSOLUTEMAX);
    let width = u64::from(nb_bits);
    if width != 0 {
        let code = u64::from(value);
        debug_assert!(code >> width == 0, "value must fit in nb_bits");
        let shift = HUF_BITS_IN_CONTAINER as u64 - width;
        (code << shift) | width
    } else {
        0
    }
}
/// Bit-level writer that appends a Huffman-coded stream to a byte vector.
///
/// Bytes are stored through a raw pointer into the vector's reserved capacity.
/// The vector's length is only updated when the stream is closed.
pub(crate) struct HufCStream<'a> {
    /// Destination vector; encoding starts at its length at construction time.
    output: &'a mut Vec<u8>,
    /// `output.len()` when the stream was created.
    start_idx: usize,
    /// Index in `output` where the next 8-byte store begins.
    cursor: usize,
    /// `start_idx + dst_capacity - 8`: the highest cursor value at which an
    /// 8-byte store still fits in the reserved region.
    end_ptr: usize,
    /// Set once the cursor had to be clamped to `end_ptr`.
    overflow: bool,
    /// Two 64-bit accumulators. New bits enter at the top. Index 1 is scratch
    /// space that `encode_unrolled` merges into index 0.
    bit_container: [u64; 2],
    /// Pending-bit counts for each accumulator. Only the low 8 bits are read;
    /// fast adds may leave noise in the upper bits.
    bit_pos: [u64; 2],
}
impl<'a> HufCStream<'a> {
    /// Opens a bit stream that appends at most `dst_capacity` bytes to `output`.
    ///
    /// Encoding starts at the current end of `output` (`output.len()`). Room for
    /// `dst_capacity` more bytes is reserved immediately. The vector's length is
    /// only extended by a successful [`close`](Self::close).
    ///
    /// Returns `None` when `dst_capacity <= 8`.
    pub(crate) fn new(output: &'a mut Vec<u8>, dst_capacity: usize) -> Option<Self> {
        if dst_capacity <= 8 {
            return None;
        }
        let start_idx = output.len();
        output.reserve(dst_capacity);
        Some(Self {
            bit_container: [0, 0],
            bit_pos: [0, 0],
            output,
            start_idx,
            cursor: start_idx,
            // Highest cursor value at which an unconditional 8-byte store still
            // lands inside the region reserved above.
            end_ptr: start_idx + dst_capacity - 8,
            overflow: false,
        })
    }

    /// Returns the top `nb_bits` bits of `container` as a right-aligned word.
    ///
    /// `nb_bits == 0` yields `0`. That case is handled separately because
    /// shifting a `u64` by 64 is an overflow in Rust.
    #[inline(always)]
    fn top_bits(container: u64, nb_bits: usize) -> u64 {
        debug_assert!(nb_bits <= HUF_BITS_IN_CONTAINER, "bit container overfilled");
        if nb_bits == 0 {
            0
        } else {
            container >> (HUF_BITS_IN_CONTAINER - nb_bits)
        }
    }

    /// Advances a write cursor by `nb_bytes`, clamping it to `end_ptr`.
    ///
    /// Returns the new cursor and whether clamping was needed. Keeping every
    /// cursor `<= end_ptr` keeps the next 8-byte store inside the reserved region.
    #[inline(always)]
    fn advance_cursor(cursor: usize, nb_bytes: usize, end_ptr: usize) -> (usize, bool) {
        let next = cursor + nb_bytes;
        (next.min(end_ptr), next > end_ptr)
    }

    /// Appends one packed code (as built by `pack_huf_celt`) to container `idx`.
    ///
    /// Existing bits shift down by the code length and the code's value enters
    /// at the top.
    ///
    /// With `FAST`, the whole of `elt` is OR-ed into the container, including its
    /// length byte, and the whole of `elt` is added to `bit_pos`. Callers must
    /// keep enough headroom that those stray low bits never reach the top bits a
    /// flush emits.
    #[inline(always)]
    pub(crate) fn add_bits<const FAST: bool>(&mut self, elt: u64, idx: usize) {
        debug_assert!(idx <= 1);
        let nb_bits = elt & 0xFF;
        debug_assert!((nb_bits as usize) <= HUF_TABLELOG_ABSOLUTEMAX);
        self.bit_container[idx] >>= nb_bits;
        let value = if FAST { elt } else { elt & !0xFFu64 };
        self.bit_container[idx] |= value;
        // Only the low 8 bits of `bit_pos` are ever read. Adding all of `elt`
        // leaves the same value there as adding just `nb_bits`.
        let nb_add = if FAST { elt } else { nb_bits };
        self.bit_pos[idx] = self.bit_pos[idx].wrapping_add(nb_add);
        debug_assert!(
            (self.bit_pos[idx] & 0xFF) as usize <= HUF_BITS_IN_CONTAINER,
            "bit container overfilled; flush before adding more bits"
        );
    }

    /// Writes the whole bytes pending in container 0 to the output.
    ///
    /// The top `bit_pos[0] & 0xFF` bits are stored as one little-endian 8-byte
    /// word at the cursor. The cursor then advances by the number of whole bytes,
    /// and the remaining `< 8` bits stay pending. If the cursor would pass
    /// `end_ptr`, it is clamped there and the stream is marked as overflowed, so
    /// [`close`](Self::close) returns 0.
    ///
    /// `FAST` is accepted for interface compatibility but no longer skips that
    /// check. Skipping it let an undersized buffer push the cursor past
    /// `end_ptr`, so the next 8-byte store wrote outside the reserved region.
    #[inline(always)]
    pub(crate) fn flush_bits<const FAST: bool>(&mut self) {
        let nb_bits = (self.bit_pos[0] & 0xFF) as usize;
        let bytes = Self::top_bits(self.bit_container[0], nb_bits).to_le_bytes();
        self.bit_pos[0] &= 7;
        // SAFETY: `new` reserved `dst_capacity` bytes after `start_idx`, and
        // `self.output` is never resized while the stream is open. The cursor
        // starts at `start_idx` and is clamped to
        // `end_ptr = start_idx + dst_capacity - 8` after every flush, so this
        // 8-byte store stays within the allocation. `bytes` is a stack array and
        // cannot overlap the vector's buffer.
        unsafe {
            let dst = self.output.as_mut_ptr().add(self.cursor);
            core::ptr::copy_nonoverlapping(bytes.as_ptr(), dst, 8);
        }
        let (cursor, clamped) = Self::advance_cursor(self.cursor, nb_bits >> 3, self.end_ptr);
        self.cursor = cursor;
        self.overflow |= clamped;
    }

    /// Encodes `data` back to front, looking each byte up in `table`.
    ///
    /// `table` holds codes packed as by `pack_huf_celt`. Encoding proceeds in
    /// three phases:
    ///
    /// 1. The last `data.len() % K_UNROLL` symbols are added in safe mode and
    ///    flushed.
    /// 2. If needed, one group of `K_UNROLL` symbols is added so that what remains
    ///    is a multiple of `2 * K_UNROLL`.
    /// 3. The main loop fills container 0 with `K_UNROLL` symbols and flushes.
    ///    It then fills container 1 with the next `K_UNROLL` symbols, merges
    ///    container 1 into container 0, and flushes again.
    ///
    /// Const parameters:
    ///
    /// * `K_FAST_FLUSH` is kept for interface compatibility. Flushes are always
    ///   clamped to the reserved region (see [`flush_bits`](Self::flush_bits)).
    /// * `K_LAST_FAST` selects whether the last symbol of each group also uses
    ///   the fast add. The other symbols in a group always do.
    ///
    /// Callers must pick `K_UNROLL` so that a group plus up to 7 leftover bits
    /// fits in one 64-bit container. Debug builds check for overfill.
    ///
    /// # Panics
    ///
    /// Panics if `K_UNROLL == 0` or if a byte of `data` indexes past `table`.
    #[inline]
    pub(crate) fn encode_unrolled<
        const K_UNROLL: usize,
        const K_FAST_FLUSH: bool,
        const K_LAST_FAST: bool,
    >(
        &mut self,
        table: &[u64],
        data: &[u8],
    ) {
        // Work on locals so the hot loop keeps its state in registers; they are
        // written back at the end.
        let mut bc0 = self.bit_container[0];
        let mut bc1 = self.bit_container[1];
        let mut bp0 = self.bit_pos[0];
        let mut bp1 = self.bit_pos[1];
        let mut cursor = self.cursor;
        let mut overflow = self.overflow;
        let end_ptr = self.end_ptr;
        let out_base = self.output.as_mut_ptr();
        // Same logic as `add_bits`, targeting container 0.
        macro_rules! add0 {
            ($elt:expr, $fast:expr) => {{
                let elt = $elt;
                let nb_bits = elt & 0xFF;
                bc0 >>= nb_bits;
                bc0 |= if $fast { elt } else { elt & !0xFFu64 };
                bp0 = bp0.wrapping_add(if $fast { elt } else { nb_bits });
            }};
        }
        // Same logic as `add_bits`, targeting container 1.
        macro_rules! add1 {
            ($elt:expr, $fast:expr) => {{
                let elt = $elt;
                let nb_bits = elt & 0xFF;
                bc1 >>= nb_bits;
                bc1 |= if $fast { elt } else { elt & !0xFFu64 };
                bp1 = bp1.wrapping_add(if $fast { elt } else { nb_bits });
            }};
        }
        // Same logic as `flush_bits`, on the local copies of the state.
        macro_rules! flush0 {
            () => {{
                let nb_bits = (bp0 & 0xFF) as usize;
                let bytes = Self::top_bits(bc0, nb_bits).to_le_bytes();
                bp0 &= 7;
                // SAFETY: `out_base` points at `self.output`'s buffer, which has
                // at least `start_idx + dst_capacity` bytes of capacity (reserved
                // in `new`) and is not reallocated during this call. `cursor` is
                // kept `<= end_ptr = start_idx + dst_capacity - 8` by the clamp
                // below, so the 8-byte store is in bounds.
                unsafe {
                    core::ptr::copy_nonoverlapping(bytes.as_ptr(), out_base.add(cursor), 8);
                }
                let (next, clamped) = Self::advance_cursor(cursor, nb_bits >> 3, end_ptr);
                cursor = next;
                overflow |= clamped;
            }};
        }
        let mut n = data.len();
        // Phase 1: bring the remaining count to a multiple of K_UNROLL.
        let rem = n % K_UNROLL;
        if rem > 0 {
            for _ in 0..rem {
                n -= 1;
                add0!(table[data[n] as usize], false);
            }
            flush0!();
        }
        debug_assert!(n.is_multiple_of(K_UNROLL));
        // Phase 2: bring it to a multiple of 2 * K_UNROLL.
        if !n.is_multiple_of(2 * K_UNROLL) {
            for u in 1..K_UNROLL {
                add0!(table[data[n - u] as usize], true);
            }
            add0!(table[data[n - K_UNROLL] as usize], K_LAST_FAST);
            flush0!();
            n -= K_UNROLL;
        }
        debug_assert!(n.is_multiple_of(2 * K_UNROLL));
        // Phase 3: two groups per iteration, the second built in container 1.
        while n > 0 {
            for u in 1..K_UNROLL {
                add0!(table[data[n - u] as usize], true);
            }
            add0!(table[data[n - K_UNROLL] as usize], K_LAST_FAST);
            flush0!();
            // Container 1 starts empty, so this group has no data dependency on
            // container 0.
            bc1 = 0;
            bp1 = 0;
            for u in 1..K_UNROLL {
                add1!(table[data[n - K_UNROLL - u] as usize], true);
            }
            add1!(table[data[n - K_UNROLL - K_UNROLL] as usize], K_LAST_FAST);
            // Merge container 1 under the top of container 0.
            let nb_bits_1 = bp1 & 0xFF;
            debug_assert!(
                (nb_bits_1 as usize) < HUF_BITS_IN_CONTAINER,
                "K_UNROLL symbols overfilled container 1"
            );
            bc0 >>= nb_bits_1;
            bc0 |= bc1;
            bp0 = bp0.wrapping_add(bp1);
            flush0!();
            n -= 2 * K_UNROLL;
        }
        debug_assert_eq!(n, 0);
        self.bit_container[0] = bc0;
        self.bit_container[1] = bc1;
        self.bit_pos[0] = bp0;
        self.bit_pos[1] = bp1;
        self.cursor = cursor;
        self.overflow = overflow;
    }

    /// Number of bits added to container 0 that have not yet been written out
    /// as whole bytes.
    #[inline(always)]
    pub(crate) fn pending_bits(&self) -> usize {
        (self.bit_pos[0] & 0xFF) as usize
    }

    /// Terminates the stream and commits the written bytes to the vector.
    ///
    /// First appends a one-bit end mark (length 1, value 1) and flushes. If any
    /// flush overflowed the capacity given to [`new`](Self::new), returns 0 and
    /// leaves the vector's length unchanged. Otherwise, extends the vector's
    /// length by the returned byte count; a trailing partial byte counts as
    /// one byte.
    pub(crate) fn close(mut self) -> usize {
        let end_mark: u64 = 1u64 | (1u64 << (HUF_BITS_IN_CONTAINER as u64 - 1));
        self.add_bits::<false>(end_mark, 0);
        self.flush_bits::<false>();
        if self.overflow {
            return 0;
        }
        let nb_bits = self.pending_bits();
        // Leftover bits (< 8) sit in one more, partially filled byte, which the
        // final flush has already stored.
        let bytes_written = (self.cursor - self.start_idx) + usize::from(nb_bits > 0);
        // SAFETY: no flush was clamped (`overflow` is false). Each flush stored
        // 8 bytes at the old cursor before advancing it by at most 8 bytes
        // (`nb_bits <= 64`, debug-checked in `top_bits`). So every byte in
        // `start_idx..=cursor` has been written. Since `cursor <= end_ptr`,
        // `start_idx + bytes_written <= start_idx + dst_capacity - 7`, which is
        // within the capacity reserved in `new`.
        unsafe {
            self.output.set_len(self.start_idx + bytes_written);
        }
        bytes_written
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pack_huf_celt_layout() {
        // Length in the low byte, value left-aligned in the top bits.
        assert_eq!(pack_huf_celt(0, 0), 0);
        assert_eq!(pack_huf_celt(0b1011, 4), 4 | (0b1011u64 << 60));
        assert_eq!(pack_huf_celt(0xFFF, 12), 12 | (0xFFFu64 << 52));
    }

    #[test]
    fn new_rejects_capacity_of_eight_or_less() {
        let mut out: Vec<u8> = Vec::new();
        assert!(HufCStream::new(&mut out, 8).is_none());
        assert!(HufCStream::new(&mut out, 9).is_some());
        assert!(out.is_empty());
    }

    #[test]
    fn add_bits_single_symbol_emits_correct_byte() {
        let mut out: Vec<u8> = Vec::new();
        let mut s = HufCStream::new(&mut out, 64).expect("init ok");
        let elt = pack_huf_celt(0b1011, 4);
        s.add_bits::<false>(elt, 0);
        let n = s.close();
        assert_eq!(n, 1);
        assert_eq!(out.len(), 1);
        assert_eq!(
            out[0], 0x1B,
            "first emitted byte must mirror upstream zstd's HUF_addBits + \
             HUF_endMark packing collapsed to a 5-bit prefix 0b11011",
        );
    }

    #[test]
    fn stream_appends_after_existing_contents() {
        let mut out: Vec<u8> = vec![0x55];
        let mut s = HufCStream::new(&mut out, 64).expect("init ok");
        s.add_bits::<false>(pack_huf_celt(0b1011, 4), 0);
        assert_eq!(s.close(), 1);
        assert_eq!(out, [0x55u8, 0x1B]);
    }

    #[test]
    fn add_bits_overflowing_container_flushes_correctly() {
        let mut out: Vec<u8> = Vec::new();
        let mut s = HufCStream::new(&mut out, 256).expect("init ok");
        for i in 0..8u32 {
            s.add_bits::<false>(pack_huf_celt(i, 8), 0);
        }
        s.flush_bits::<false>();
        assert_eq!(s.cursor - s.start_idx, 8);
        assert_eq!(s.pending_bits(), 0);
        let n = s.close();
        assert_eq!(n, 9, "64 payload bits + 1 end-mark bit = 9 bytes");
        // The first symbol added lands in the lowest-addressed byte; the end
        // mark alone fills the last byte.
        assert_eq!(out, [0u8, 1, 2, 3, 4, 5, 6, 7, 1]);
    }

    #[test]
    fn encode_unrolled_dual_container_size_is_deterministic() {
        let mut out: Vec<u8> = Vec::new();
        let mut s = HufCStream::new(&mut out, 64).expect("init ok");
        let table = [pack_huf_celt(0b1010, 4); 256];
        let data = [0u8; 16];
        s.encode_unrolled::<4, false, false>(&table, &data);
        let n = s.close();
        assert_eq!(
            n, 9,
            "16 symbols * 4 bits + 1 end-mark bit = 65 bits = 9 bytes"
        );
        assert_eq!(out, [0xAAu8, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0x01]);
    }

    #[test]
    fn encode_unrolled_handles_remainder_and_join_paths() {
        // 5 symbols with K_UNROLL = 4: one remainder symbol, then one join group.
        let mut out: Vec<u8> = Vec::new();
        let mut s = HufCStream::new(&mut out, 64).expect("init ok");
        let table = [pack_huf_celt(0b1010, 4); 256];
        s.encode_unrolled::<4, false, false>(&table, &[0u8; 5]);
        assert_eq!(s.close(), 3, "5 symbols * 4 bits + 1 end-mark bit = 21 bits");
        assert_eq!(out, [0xAAu8, 0xAA, 0x1A]);
    }
}