// safebit 0.1.0
//
// Safe and secure bit access into integer types.
// Documentation
//! Bitstream implementations.

use crate::{
    error::SafebitStreamError,
    util::Endianess,
    word::{TypeCast, Word},
    word_read::WordRead,
};
use core::marker::PhantomData;

/// A bit stream over [words](Word) supplied by a specific [WordRead]
/// implementation.
pub struct Bitstream<'a, DW: Word, E, R: WordRead<DW, E, End>, End: Endianess> {
    /// Source of the words, borrowed mutably for the lifetime of the stream.
    reader: &'a mut R,
    /// Most recently fetched word; kept around so reads need not be word-aligned.
    buf: DW,
    /// How many bits of `buf` have already been handed out.
    buf_consumed_bits: usize,
    /// Marker for the otherwise unused reader error type `E`.
    phantom1: PhantomData<&'a E>,
    /// Marker for the otherwise unused endianness type `End`.
    phantom2: PhantomData<&'a End>,
}

impl<'a, DW, E, WR, End> Bitstream<'a, DW, E, WR, End>
where
    DW: Word,
    WR: WordRead<DW, E, End>,
    End: Endianess,
{
    /// Create a new `Bitstream` over a reader of words.
    ///
    /// Nothing is read yet: the internal buffer starts out marked as fully
    /// consumed, so the first call to [`read`](Self::read) fetches a word
    /// from `reader`.
    pub fn new(reader: &'a mut WR) -> Self {
        Self {
            reader,
            buf: DW::ZERO,
            buf_consumed_bits: DW::BIT_LEN,
            phantom1: PhantomData,
            phantom2: PhantomData,
        }
    }

    /// Read the next `bit_len` bits from the stream into a single
    /// [output word](Word).
    ///
    /// Bits are taken from each buffered word starting at its most
    /// significant unconsumed bit. The first bit read becomes the most
    /// significant of the `bit_len` low bits of the result. If `OW` is
    /// signed and `bit_len < OW::BIT_LEN`, the result is sign-extended from
    /// bit `bit_len - 1`. A full-width read is returned unchanged.
    ///
    /// # Errors
    ///
    /// - [`SafebitStreamError::InvalidLength`] if `bit_len` is `0` or larger
    ///   than `OW::BIT_LEN`. In that case no bits are consumed.
    /// - Any error returned by the underlying [`WordRead`] while fetching a
    ///   new word, converted through `From<E>`. Bits this call consumed
    ///   before the failure are not given back to the stream.
    pub fn read<OW>(&mut self, bit_len: usize) -> Result<OW, SafebitStreamError<E>>
    where
        OW: Word,
        DW: TypeCast<OW>,
        SafebitStreamError<E>: From<E>,
    {
        if bit_len == 0 || bit_len > OW::BIT_LEN {
            return Err(SafebitStreamError::InvalidLength { bit_len });
        }

        let mut output = OW::ZERO;
        let mut remaining_bit_len = bit_len;
        while remaining_bit_len > 0 {
            // Read a new word if the buffered one is used up.
            if self.buf_consumed_bits == DW::BIT_LEN {
                self.buf = self.reader.read()?;
                self.buf_consumed_bits = 0;
            }

            let buf_remaining_bits = DW::BIT_LEN - self.buf_consumed_bits;
            let subslice_len = buf_remaining_bits.min(remaining_bit_len);

            // Shift left to drop the already consumed high bits. Then shift
            // right so that only the `subslice_len` wanted bits remain,
            // right-aligned. Both shift amounts are < DW::BIT_LEN:
            // `buf_consumed_bits < DW::BIT_LEN` after the refill above, and
            // `subslice_len >= 1`.
            // NOTE(review): this assumes `>>` on DW is a logical shift. If DW
            // can be a signed type, the arithmetic shift would copy the sign
            // bit into the result. Confirm that DW is always unsigned.
            let left_shifted = self.buf << self.buf_consumed_bits;
            let subslice = left_shifted >> (DW::BIT_LEN - subslice_len);
            let subslice_ow: OW = subslice.typecast();

            remaining_bit_len -= subslice_len;
            self.buf_consumed_bits += subslice_len;

            // Place this chunk directly below the chunks read before it.
            output |= subslice_ow << remaining_bit_len;
        }

        // OW might not be fully filled, so the sign must be fixed. Because
        // output started at 0, only the negative case needs work. A full-width
        // read already has its sign bit in place and must be skipped. Shifting
        // by `OW::BIT_LEN` overflows: it panics in debug builds, and for
        // primitive integers in release it masks the shift to 0, which turns
        // the mask into all ones and corrupts the value.
        if OW::IS_SIGNED && bit_len < OW::BIT_LEN {
            let sign_mask = OW::ONE << (bit_len - 1);
            if output & sign_mask != OW::ZERO {
                let ones: OW::UnsignedType = !OW::UnsignedType::ZERO;
                // Ones in bits [bit_len, OW::BIT_LEN). No shift here can
                // overflow, since bit_len < OW::BIT_LEN.
                let mask = (ones >> bit_len) << bit_len;
                output |= mask.typecast();
            }
        }

        Ok(output)
    }
}

#[cfg(test)]
mod tests {
    use std::{fs::File, io::BufReader, path::Path};
    extern crate std;
    use crate::{bitstream::Bitstream, error::SafebitStreamError, util::LittleEndian};

    /// Open this crate's `Cargo.toml`, resolved from `CARGO_MANIFEST_DIR`
    /// so the tests do not depend on the current working directory.
    ///
    /// The assertions below assume the manifest starts with the bytes `[p`
    /// (as in `[package]`).
    fn open_manifest() -> BufReader<File> {
        let path = Path::new(env!("CARGO_MANIFEST_DIR")).join("Cargo.toml");
        BufReader::new(File::open(path).expect("Cargo.toml should be readable"))
    }

    #[test]
    fn test_safe_bit_stream() {
        let mut buf_reader = open_manifest();
        let mut stream: Bitstream<u16, _, _, LittleEndian> = Bitstream::new(&mut buf_reader);
        let w1: u16 = stream.read(8).unwrap();
        assert_eq!(w1, 0x70);
        let w2: u16 = stream.read(8).unwrap();
        assert_eq!(w2, 0x5b);
    }

    #[test]
    fn test_invalid_length_is_rejected_without_consuming() {
        let mut buf_reader = open_manifest();
        let mut stream: Bitstream<u16, _, _, LittleEndian> = Bitstream::new(&mut buf_reader);
        assert!(matches!(
            stream.read::<u16>(0),
            Err(SafebitStreamError::InvalidLength { bit_len: 0 })
        ));
        assert!(matches!(
            stream.read::<u16>(17),
            Err(SafebitStreamError::InvalidLength { bit_len: 17 })
        ));
        // Rejected reads must not have advanced the stream.
        let w1: u16 = stream.read(8).unwrap();
        assert_eq!(w1, 0x70);
    }
}