bytesio/
bits_writer.rs

1use {
2    super::{
3        bits_errors::{BitError, BitErrorValue},
4        bytes_writer::BytesWriter,
5    },
6    bytes::BytesMut,
7};
8
9pub struct BitsWriter {
10    writer: BytesWriter,
11    cur_byte: u8,
12    cur_bit_num: u8,
13}
14
15impl BitsWriter {
16    pub fn new(writer: BytesWriter) -> Self {
17        Self {
18            writer,
19            cur_byte: 0,
20            cur_bit_num: 0,
21        }
22    }
23
24    pub fn write_bytes(&mut self, data: BytesMut) -> Result<(), BitError> {
25        self.writer.write(&data[..])?;
26        Ok(())
27    }
28
29    pub fn write_bit(&mut self, b: u8) -> Result<(), BitError> {
30        self.cur_byte |= b << (7 - self.cur_bit_num);
31        self.cur_bit_num += 1;
32
33        if self.cur_bit_num == 8 {
34            self.writer.write_u8(self.cur_byte)?;
35            self.cur_bit_num = 0;
36            self.cur_byte = 0;
37        }
38
39        Ok(())
40    }
41
42    pub fn write_8bit(&mut self, b: u8) -> Result<(), BitError> {
43        if self.cur_bit_num != 0 {
44            return Err(BitError {
45                value: BitErrorValue::CannotWrite8Bit,
46            });
47        }
48
49        self.writer.write_u8(b)?;
50        Ok(())
51    }
52
53    fn flush(&mut self) -> Result<(), BitError> {
54        if self.cur_bit_num == 8 {
55            self.writer.write_u8(self.cur_byte)?;
56            self.cur_bit_num = 0;
57            self.cur_byte = 0;
58        } else {
59            log::trace!("cannot flush: {}", self.cur_bit_num);
60        }
61
62        Ok(())
63    }
64
65    // 0x02 4
66    pub fn write_n_bits(&mut self, data: u64, bit_num: usize) -> Result<(), BitError> {
67        if bit_num > 64 {
68            return Err(BitError {
69                value: BitErrorValue::TooBig,
70            });
71        }
72        let mut bit_num_mut = bit_num;
73        let mut data_mut = data;
74
75        //read left bits  for current byte
76        data_mut <<= 64 - bit_num;
77        self.cur_byte |= (data_mut >> (56 + self.cur_bit_num)) as u8;
78
79        let cur_byte_left_bit_num = 8 - self.cur_bit_num as usize;
80        if bit_num_mut >= cur_byte_left_bit_num {
81            // the bits for current byte is full, then flush
82            data_mut <<= cur_byte_left_bit_num;
83            bit_num_mut -= cur_byte_left_bit_num;
84            self.cur_bit_num = 8;
85            self.flush()?;
86        } else {
87            // not full, only update bit num
88            self.cur_bit_num += bit_num_mut as u8;
89            return Ok(());
90        }
91
92        while bit_num_mut > 0 {
93            self.cur_byte = (data_mut >> 56) as u8;
94
95            if bit_num_mut > 8 {
96                self.cur_bit_num = 8;
97                self.flush()?;
98                data_mut <<= 8;
99                bit_num_mut -= 8;
100            } else {
101                self.cur_bit_num = bit_num_mut as u8;
102                break;
103            }
104        }
105
106        Ok(())
107    }
108
109    pub fn bits_aligment_8(&mut self) -> Result<(), BitError> {
110        self.cur_bit_num = 8;
111        self.flush()?;
112        Ok(())
113    }
114
115    pub fn get_current_bytes(&self) -> BytesMut {
116        self.writer.get_current_bytes()
117    }
118
119    pub fn len(&self) -> usize {
120        self.writer.len() * 8 + self.cur_bit_num as usize
121    }
122    pub fn is_empty(&self) -> bool {
123        self.len() == 0
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::BitsWriter;
    use super::BytesWriter;

    /// Feeds every element of `bits` to `write_bit`, in order.
    fn push_bits(bit_writer: &mut BitsWriter, bits: &[u8]) {
        for &bit in bits {
            bit_writer.write_bit(bit).unwrap();
        }
    }

    #[test]
    fn test_write_bit() {
        let mut bit_writer = BitsWriter::new(BytesWriter::new());

        // Eight bits (0b0000_0010) complete the first byte.
        push_bits(&mut bit_writer, &[0, 0, 0, 0, 0, 0, 1, 0]);

        let byte = bit_writer.get_current_bytes();
        assert_eq!(byte.to_vec()[0], 0x2);

        // Two more bits stay pending in the accumulator.
        push_bits(&mut bit_writer, &[1, 1]);

        println!("=={}=={}==", bit_writer.cur_bit_num, bit_writer.cur_byte);
        assert_eq!(bit_writer.cur_bit_num, 2);
        assert_eq!(bit_writer.cur_byte, 0xC0); // 0b1100_0000
    }

    #[test]
    fn test_write_n_bits() {
        let mut bit_writer = BitsWriter::new(BytesWriter::new());

        // Three single bits followed by the seven-bit value 0b000_0011.
        push_bits(&mut bit_writer, &[1, 1, 0]);
        bit_writer.write_n_bits(0x03, 7).unwrap();

        let byte = bit_writer.get_current_bytes();

        // Expected stream: 0b1100_0000 (completed byte), then 0b11 pending.
        println!("=={}=={}==", bit_writer.cur_bit_num, bit_writer.cur_byte);
        println!("=={}==", byte.to_vec()[0]);

        assert_eq!(byte.to_vec()[0], 0xC0); // 0b1100_0000

        assert_eq!(bit_writer.cur_bit_num, 2);
        assert_eq!(bit_writer.cur_byte, 0xC0); // 0b1100_0000
    }

    #[test]
    fn test_bits_aligment_8() {
        let mut bit_writer = BitsWriter::new(BytesWriter::new());

        // Three pending bits are padded out to a full byte by the alignment.
        push_bits(&mut bit_writer, &[1, 1, 0]);
        bit_writer.bits_aligment_8().unwrap();

        let byte = bit_writer.get_current_bytes();
        assert_eq!(byte.to_vec()[0], 0xC0); // 0b1100_0000

        // Writing resumes at the start of a fresh byte.
        push_bits(&mut bit_writer, &[1, 1, 0]);

        assert_eq!(bit_writer.cur_bit_num, 3);
        assert_eq!(bit_writer.cur_byte, 0xC0); // 0b1100_0000
    }
}