ruzstd 0.3.0

A decoder for the zstd compression format
Documentation
/// Forward bit-level cursor over a borrowed byte slice.
///
/// The reader never copies the underlying data; it only tracks how many bits
/// have been consumed from the start of `source`.
#[derive(Debug)]
pub struct BitReader<'s> {
    /// Number of bits already consumed, counted from the start of `source`.
    idx: usize,
    /// The bytes the bits are read from.
    source: &'s [u8],
}

/// Errors reported by `BitReader::get_bits`.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum GetBitsError {
    /// More bits were requested in one call than the reader can serve at once.
    #[error("Cant serve this request. The reader is limited to {limit} bits, requested {num_requested_bits} bits")]
    TooManyBits {
        /// The number of bits the caller asked for.
        num_requested_bits: usize,
        /// The maximum number of bits a single request may ask for.
        limit: u8,
    },
    /// The request asked for more bits than are still unread in the source.
    #[error("Can't read {requested} bits, only have {remaining} bits left")]
    NotEnoughRemainingBits { requested: usize, remaining: usize },
}

impl<'s> BitReader<'s> {
    /// Creates a reader positioned at the first bit of `source`.
    pub fn new(source: &'s [u8]) -> BitReader<'_> {
        BitReader { idx: 0, source }
    }

    /// Returns the number of bits that have not been consumed yet.
    pub fn bits_left(&self) -> usize {
        self.source.len() * 8 - self.idx
    }

    /// Returns the number of bits consumed so far.
    pub fn bits_read(&self) -> usize {
        self.idx
    }

    /// Moves the cursor back by `n` bits so they can be read again.
    ///
    /// # Panics
    ///
    /// Panics if `n` is larger than the number of bits read so far.
    pub fn return_bits(&mut self, n: usize) {
        if n > self.idx {
            panic!("Cant return this many bits");
        }
        self.idx -= n;
    }

    /// Reads the next `n` bits and advances the cursor by `n`.
    ///
    /// Within each byte the bits are consumed starting at the least
    /// significant bit, and bits that are read earlier end up in the lower
    /// positions of the returned value. Requesting zero bits always succeeds
    /// and yields `0`, even when the source is empty or fully consumed.
    ///
    /// # Errors
    ///
    /// * [`GetBitsError::TooManyBits`] if `n` is greater than 64.
    /// * [`GetBitsError::NotEnoughRemainingBits`] if fewer than `n` bits are
    ///   left; the cursor is not moved in that case.
    pub fn get_bits(&mut self, n: usize) -> Result<u64, GetBitsError> {
        if n > 64 {
            return Err(GetBitsError::TooManyBits {
                num_requested_bits: n,
                limit: 64,
            });
        }
        let remaining = self.bits_left();
        if remaining < n {
            return Err(GetBitsError::NotEnoughRemainingBits {
                requested: n,
                remaining,
            });
        }

        let start = self.idx;
        let mut value = 0u64;
        // Number of bits already placed into `value`.
        let mut collected = 0;

        // Each iteration takes as many bits as possible from the byte under
        // the cursor. The loop does not run for `n == 0`, so no byte is
        // touched when the cursor sits at (or past) the end of the source.
        while collected < n {
            let offset = self.idx % 8;
            let take = (8 - offset).min(n - collected);
            let byte = u64::from(self.source[self.idx / 8]);
            // `take` is at most 8, so the mask shift cannot overflow.
            let chunk = (byte >> offset) & ((1u64 << take) - 1);
            // `collected` is below `n <= 64` here, so this shift is in range.
            value |= chunk << collected;
            self.idx += take;
            collected += take;
        }

        debug_assert_eq!(self.idx, start + n);

        Ok(value)
    }

    /// Rewinds the cursor to bit zero and switches to reading `new_source`.
    pub fn reset(&mut self, new_source: &'s [u8]) {
        self.idx = 0;
        self.source = new_source;
    }
}