linflate 0.1.0

Fast pure-Rust DEFLATE decompressor — SIMD match-copy, branchless refill, segment-aware
Documentation
//! Branchless 64-bit DEFLATE bit reader.
//!
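//! Design: libdeflate/zlib-ng style — a 64-bit shift register whose refill
//! fast path is one unaligned 8-byte load plus branch-free pointer/count
//! updates; the only branch is the bounds check that diverts the last few
//! input bytes to a byte-at-a-time slow path. The struct is 40 bytes on
//! 64-bit targets (three pointers, a `u64` and a `u32`, plus padding).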
//!
//! After `refill()`, at least 56 bits are available as long as enough input
//! remains. A single DEFLATE symbol (litlen + extra + dist + extra) needs at
//! most 48 bits (15 + 5 + 15 + 13), so one refill per symbol is sufficient.

/// Branchless bit reader for DEFLATE decompression.
///
/// Reads bits LSB-first from a byte stream. All methods are `#[inline(always)]`
/// for the hot decode loop.
#[derive(Debug)]
pub struct BitReader {
    /// Next input byte that has not yet been loaded into `buf`.
    ptr: *const u8,
    /// Fast-path limit: `end - 7` for inputs of at least 8 bytes (so an 8-byte
    /// load at any `ptr < safe_end` stays in bounds), otherwise the input start.
    safe_end: *const u8,
    /// One past the last input byte.
    end: *const u8,
    /// 64-bit shift register; bit 0 is the next bit to consume.
    buf: u64,
    /// Number of valid bits currently in `buf` (0..=63).
    bits: u32,
}

// SAFETY: `BitReader` holds raw pointers into a caller-provided `&[u8]` and
// only ever reads through them (no writes), so moving it to another thread is
// as sound as moving that shared slice reference, which is `Send`. The caller
// must still keep the slice alive for as long as the reader is used.
unsafe impl Send for BitReader {}

impl BitReader {
    /// Create a new `BitReader` over the given compressed data.
    ///
    /// The reader stores raw pointers into `input` rather than a borrow, so the
    /// slice must remain alive for as long as the reader is used (see the
    /// `# Safety` section of [`BitReader::refill`]).
    #[inline]
    pub fn new(input: &[u8]) -> Self {
        let range = input.as_ptr_range();
        // `safe_end` is `end - 7`: whenever `ptr < safe_end`, the 8 bytes
        // `[ptr, ptr + 8)` lie inside `input`, so `refill` can load a whole u64
        // without overreading. Inputs shorter than 8 bytes get
        // `safe_end == start`, which means they never take the fast path.
        let safe_end = if input.len() >= 8 {
            input[input.len() - 7..].as_ptr()
        } else {
            range.start
        };
        Self {
            ptr: range.start,
            safe_end,
            end: range.end,
            buf: 0,
            bits: 0,
        }
    }

    /// Refill `buf` so that at least 56 valid bits are available, or as many
    /// as the remaining input allows.
    ///
    /// Fast path (at least 8 readable bytes left), the zlib-ng/libdeflate trick:
    /// - load 8 bytes (unaligned, little-endian) and OR them into `buf` above
    ///   the `bits` already held;
    /// - advance `ptr` by `(63 ^ bits) >> 3` whole bytes, which for
    ///   `bits <= 63` is `7 - bits / 8`;
    /// - set `bits |= 56`, which equals `bits + 8 * advance` for `bits <= 63`.
    ///
    /// The fast path has no data-dependent branch. The only branch is the
    /// bounds check that sends the last few input bytes to the byte-at-a-time
    /// slow path.
    ///
    /// Bits of `buf` at and above `bits` may hold copies of the upcoming input
    /// bytes. A later refill ORs those same bytes in at the same positions,
    /// which is idempotent.
    ///
    /// # Safety
    ///
    /// The slice passed to [`BitReader::new`] must still be alive.
    #[inline(always)]
    pub unsafe fn refill(&mut self) {
        debug_assert!(self.bits <= 63);
        if self.ptr < self.safe_end {
            // SAFETY: `ptr < safe_end == end - 7`, so `[ptr, ptr + 8)` is inside
            // the (caller-guaranteed live) input, and `advance <= 7` keeps
            // `ptr` within it.
            unsafe {
                let raw = core::ptr::read_unaligned(self.ptr as *const u64);
                self.buf |= u64::from_le(raw) << self.bits;
                let advance = ((63 ^ self.bits) >> 3) as usize;
                self.ptr = self.ptr.add(advance);
            }
            self.bits |= 56;
        } else {
            self.refill_slow();
        }
    }

    /// Slow-path refill for the last few bytes of input.
    ///
    /// Feeds one byte at a time while fewer than 56 bits are buffered and input
    /// remains. Stopping below 56 (rather than filling up to 64) keeps
    /// `bits <= 63`. Both the debug assertion in `refill` and the shifts in
    /// this type depend on that invariant.
    #[cold]
    #[inline(never)]
    fn refill_slow(&mut self) {
        while self.bits < 56 && self.ptr < self.end {
            // SAFETY: `ptr < end`, so it points at a byte of the input slice,
            // which the caller of `refill` guarantees is still alive.
            let byte = unsafe { *self.ptr };
            self.buf |= u64::from(byte) << self.bits;
            self.ptr = unsafe { self.ptr.add(1) };
            self.bits += 8;
        }
    }

    /// Peek at the lowest `n` bits (`n <= 32`) without consuming them.
    #[inline(always)]
    pub fn peek(&self, n: u32) -> u32 {
        debug_assert!(n <= 32 && n <= self.bits);
        // Mask in 64-bit arithmetic. `1u32 << 32` would panic in debug builds
        // and produce a zero mask in release builds.
        (self.buf & ((1u64 << n) - 1)) as u32
    }

    /// Peek at the lowest `n` bits (`n <= 56`) as u64 without consuming them.
    #[inline(always)]
    pub fn peek64(&self, n: u32) -> u64 {
        debug_assert!(n <= 56 && n <= self.bits);
        self.buf & ((1u64 << n) - 1)
    }

    /// Return the 32 bits of `buf` starting at bit offset `skip` (`skip < 64`),
    /// without consuming anything. The caller masks off the width it needs.
    ///
    /// Used for the combined length + extra-bits decode pattern.
    #[inline(always)]
    pub fn peek_at(&self, skip: u32) -> u32 {
        debug_assert!(skip < 64);
        (self.buf >> skip) as u32
    }

    /// Consume `n` bits (shift them out of the buffer).
    ///
    /// `n` is clamped to the number of buffered bits in every build profile.
    /// Over-consuming at end-of-stream padding therefore empties the buffer
    /// instead of underflowing `bits`.
    #[inline(always)]
    pub fn consume(&mut self, n: u32) {
        debug_assert!(n <= 64);
        let n = n.min(self.bits);
        self.buf >>= n;
        self.bits -= n;
    }

    /// Consume `n` bits without clamping (hot loop only).
    ///
    /// # Safety
    ///
    /// The caller must guarantee `n <= self.bits`. Otherwise `bits` underflows:
    /// this panics in debug builds and silently wraps in release builds.
    #[inline(always)]
    pub unsafe fn consume_unchecked(&mut self, n: u32) {
        debug_assert!(n <= self.bits);
        self.buf >>= n;
        self.bits -= n;
    }

    /// Consume `n` bits (`n <= 32`) and return their value.
    #[inline(always)]
    pub fn take(&mut self, n: u32) -> u32 {
        let v = self.peek(n);
        self.consume(n);
        v
    }

    /// Keep only the low `n` bits of `value` (a saved copy of the buffer).
    ///
    /// Equivalent to `value & ((1 << n) - 1)`, with `n >= 64` returning `value`
    /// unchanged. On x86_64 built with BMI2 this is the single-uop BZHI
    /// instruction. Both paths agree for every `n <= 255`.
    #[inline(always)]
    pub fn extract_var(value: u64, n: u32) -> u64 {
        #[cfg(target_arch = "x86_64")]
        {
            if cfg!(target_feature = "bmi2") {
                // SAFETY: this branch only executes when the crate is compiled
                // with the `bmi2` target feature, so BZHI is available.
                return unsafe { core::arch::x86_64::_bzhi_u64(value, n) };
            }
        }
        // `checked_shl` is `None` for `n >= 64`. Like BZHI, clear nothing then
        // instead of overflowing the shift.
        match 1u64.checked_shl(n) {
            Some(bit) => value & (bit - 1),
            None => value,
        }
    }

    /// Number of valid bits remaining in the buffer.
    #[inline(always)]
    pub fn bits_remaining(&self) -> u32 {
        self.bits
    }

    /// Whether all input has been loaded AND the bit buffer is empty.
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        self.ptr >= self.end && self.bits == 0
    }

    /// Current input pointer (for fastloop bounds checking).
    ///
    /// Bytes before this pointer have been loaded into the bit buffer, though
    /// not necessarily consumed.
    #[inline(always)]
    pub fn input_ptr(&self) -> *const u8 {
        self.ptr
    }

    /// Input end pointer (one past the last byte).
    #[inline(always)]
    pub fn input_end(&self) -> *const u8 {
        self.end
    }

    /// Align to a byte boundary by discarding the bits of a partial byte.
    ///
    /// Only whole bytes are ever loaded. The stream is therefore byte-aligned
    /// exactly when `bits` is a multiple of 8. Used for stored blocks, which
    /// start byte-aligned.
    #[inline(always)]
    pub fn align_to_byte(&mut self) {
        let discard = self.bits & 7;
        self.consume(discard);
    }

    /// Read a little-endian u16 from the (byte-aligned) bit buffer.
    ///
    /// Used for stored block LEN/NLEN fields.
    #[inline(always)]
    pub fn take_u16(&mut self) -> u16 {
        debug_assert!(self.bits >= 16);
        // `buf` already holds the stream decoded LSB-first (`refill` uses
        // `u64::from_le`), so its low 16 bits are the field value. Another
        // `to_le()` here would byte-swap on big-endian hosts.
        let v = self.buf as u16;
        self.consume(16);
        v
    }

    /// The raw bit buffer value (for preloading Huffman entries).
    #[inline(always)]
    pub fn raw_buf(&self) -> u64 {
        self.buf
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a reader over `data` and perform the initial refill.
    ///
    /// Every caller keeps `data` alive for as long as it uses the reader.
    fn primed(data: &[u8]) -> BitReader {
        let mut reader = BitReader::new(data);
        unsafe { reader.refill() };
        reader
    }

    #[test]
    fn basic_read() {
        // Bits come out LSB-first: 0xB4 = 0b1011_0100 yields 0b0100, then 0b1011.
        let data = [0xB4u8, 0x69, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF];
        let mut reader = primed(&data);
        assert!(reader.bits_remaining() >= 56);
        let low_nibble = reader.take(4);
        let high_nibble = reader.take(4);
        let second_byte = reader.take(8);
        assert_eq!(low_nibble, 0b0100);
        assert_eq!(high_nibble, 0b1011);
        assert_eq!(second_byte, 0b0110_1001);
    }

    #[test]
    fn refill_guarantees_56_bits() {
        let data = vec![0xAAu8; 64];
        let mut reader = primed(&data);
        assert!(reader.bits_remaining() >= 56);
        reader.consume(48);
        unsafe { reader.refill() };
        assert!(reader.bits_remaining() >= 56);
    }

    #[test]
    fn small_input() {
        let data = [0x42u8, 0x37];
        let mut reader = primed(&data);
        assert!(reader.bits_remaining() >= 16);
        for &byte in &data {
            assert_eq!(reader.take(8), u32::from(byte));
        }
    }

    #[test]
    fn align_to_byte() {
        let data = [0xFF; 8];
        let mut reader = primed(&data);
        reader.consume(3);
        // Dropping the 5 leftover bits of the partial byte re-aligns the stream.
        reader.align_to_byte();
        assert_eq!(reader.bits_remaining() % 8, 0);
    }

    #[test]
    fn extract_var_matches_mask() {
        let cases = [(8, 0xEF), (16, 0xBEEF), (32, 0xDEADBEEF)];
        for &(width, expected) in &cases {
            assert_eq!(BitReader::extract_var(0xDEADBEEF, width), expected);
        }
    }

    #[test]
    fn peek_does_not_consume() {
        let data = [0xAB; 8];
        let reader = primed(&data);
        let first = reader.peek(8);
        assert_eq!(first, reader.peek(8));
        assert_eq!(first, 0xAB);
    }

    #[test]
    fn empty_input() {
        let data: &[u8] = &[];
        assert!(BitReader::new(data).is_empty());
    }
}