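//! Buffered bit writer: packs individual bits MSB-first into bytes and
//! writes the bytes to an underlying `std::io::Write` sink.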
use std::io::{self, Write};

/// Capacity passed to `Writer::with_capacity` by `Writer::new`.
/// `with_capacity` allocates `capacity + 1` bytes, so this yields a 1024-byte buffer.
const DEFAULT_BUFFER_SIZE: usize = 1024 - 1;

/// Buffered bit writer: each byte is filled from its most significant bit
/// down to its least significant bit before moving on to the next byte.
// TODO: make it abstract over BitStorage
pub struct Writer<W> {
    /// Pending bits; bytes beyond the current write position are kept zeroed.
    buffer: Box<[u8]>,
    /// Number of bits (not bytes) currently stored in `buffer`.
    size: usize,
    /// Sink that receives whole bytes whenever the buffer is flushed.
    writer: W,
}

impl<W> Writer<W>
where
    W: Write,
{
    /// Creates a bit writer over `writer` using the default buffer size.
    pub fn new(writer: W) -> Self {
        Self::with_capacity(DEFAULT_BUFFER_SIZE, writer)
    }

    /// Creates a bit writer whose internal buffer holds `capacity + 1` bytes.
    ///
    /// The extra byte guarantees the buffer is never empty, so even
    /// `capacity == 0` yields a usable one-byte buffer.
    pub fn with_capacity(capacity: usize, writer: W) -> Self {
        Writer {
            buffer: vec![0; capacity + 1].into_boxed_slice(),
            size: 0,
            writer,
        }
    }

    /// Appends a single bit, filling each byte from its MSB towards its LSB.
    ///
    /// If the buffer is full, its contents are first written to the
    /// underlying writer (without flushing that writer).
    ///
    /// # Errors
    /// Propagates any error from writing the full buffer out.
    pub fn write_bit(&mut self, bit: bool) -> io::Result<()> {
        if self.size == self.buffer.len() * 8 {
            self.flush_buffer()?;
        }

        // Buffer bytes past `size` are always zero, so only 1-bits need setting.
        if bit {
            let byte_index = self.size / 8;
            let bit_index = 7 - self.size % 8;
            self.buffer[byte_index] |= 1 << bit_index;
        }

        self.size += 1;
        Ok(())
    }

    /// Writes the top `count` bits of `bits`, most significant bit first.
    ///
    /// For example `write_bits(0b1010_0000, 3)` appends the bits `1, 0, 1`;
    /// the low `8 - count` bits of `bits` are ignored.
    ///
    /// # Errors
    /// Returns `ErrorKind::InvalidInput` (writing nothing) if `count > 8`,
    /// and propagates errors from the underlying writer.
    // TODO: optimize
    // TODO: abstract over |bits|
    // TODO: use byteorder and just write in BE order into
    //       buffer when size_of::<Bits>() > 1 and buffer is aligned
    pub fn write_bits(&mut self, bits: u8, count: usize) -> io::Result<()> {
        // Reject up front: a u8 has only 8 bits, and `7 - i` would underflow.
        if count > 8 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "write_bits: count must be at most 8",
            ));
        }

        for i in 0..count {
            self.write_bit((bits & (0x80 >> i)) != 0)?;
        }

        Ok(())
    }

    /// Writes all buffered bits to the underlying writer, then flushes it.
    ///
    /// This is a byte-level flush: a partially filled final byte is padded
    /// with 0 bits, and the next bit written starts a fresh byte.
    ///
    /// # Errors
    /// Propagates errors from writing or flushing the underlying writer.
    pub fn flush(&mut self) -> io::Result<()> {
        self.flush_buffer()?;
        self.writer.flush()
    }

    /// Writes the used bytes of the buffer (rounding a partial final byte up)
    /// to the underlying writer and empties the buffer, without flushing `writer`.
    fn flush_buffer(&mut self) -> io::Result<()> {
        let bytes = (self.size + 7) / 8;
        self.writer.write_all(&self.buffer[..bytes])?;
        self.reset(bytes);
        Ok(())
    }

    /// Zeroes the first `bytes` bytes of the buffer and marks it empty.
    ///
    /// Only bytes that were in use can contain set bits, so clearing them
    /// restores the all-zero invariant `write_bit` relies on.
    fn reset(&mut self, bytes: usize) {
        for byte in &mut self.buffer[..bytes] {
            *byte = 0;
        }
        self.size = 0;
    }

    /// Flushes all pending bits (padding the last byte with zeros) and
    /// returns the underlying writer.
    ///
    /// # Errors
    /// Propagates errors from writing or flushing the underlying writer.
    pub fn finish(mut self) -> io::Result<W> {
        self.flush()?;
        Ok(self.writer)
    }
}


#[cfg(test)]
mod tests {
    use super::Writer;

    /// Feeds each `(bits, count)` chunk to `writer` in order and returns the
    /// bytes produced once the writer is finished.
    fn emit(mut writer: Writer<Vec<u8>>, chunks: &[(u8, usize)]) -> Vec<u8> {
        for &(bits, count) in chunks {
            writer.write_bits(bits, count).unwrap();
        }
        writer.finish().unwrap()
    }

    #[test]
    fn within_byte() {
        // Two 4-bit chunks fill exactly one byte.
        let out = emit(Writer::new(Vec::new()), &[(0b1010_0000, 4), (0b0101_0000, 4)]);
        assert_eq!(out, vec![0b1010_0101]);
    }

    #[test]
    fn overlapping() {
        // Twelve bits straddle a byte boundary; the tail is zero-padded.
        let chunk = (0b1010_1000, 6);
        let out = emit(Writer::new(Vec::new()), &[chunk, chunk]);
        assert_eq!(out, vec![0b1010_1010, 0b1010_0000]);
    }

    #[test]
    fn buffer_overflow() {
        // A one-byte buffer forces an intermediate flush mid-stream.
        let chunk = (0b1010_1000, 6);
        let out = emit(Writer::with_capacity(0, Vec::new()), &[chunk, chunk]);
        assert_eq!(out, vec![0b1010_1010, 0b1010_0000]);
    }
}