1use {
2 super::{
3 bits_errors::{BitError, BitErrorValue},
4 bytes_writer::BytesWriter,
5 },
6 bytes::BytesMut,
7};
8
9pub struct BitsWriter {
10 writer: BytesWriter,
11 cur_byte: u8,
12 cur_bit_num: u8,
13}
14
15impl BitsWriter {
16 pub fn new(writer: BytesWriter) -> Self {
17 Self {
18 writer,
19 cur_byte: 0,
20 cur_bit_num: 0,
21 }
22 }
23
24 pub fn write_bytes(&mut self, data: BytesMut) -> Result<(), BitError> {
25 self.writer.write(&data[..])?;
26 Ok(())
27 }
28
29 pub fn write_bit(&mut self, b: u8) -> Result<(), BitError> {
30 self.cur_byte |= b << (7 - self.cur_bit_num);
31 self.cur_bit_num += 1;
32
33 if self.cur_bit_num == 8 {
34 self.writer.write_u8(self.cur_byte)?;
35 self.cur_bit_num = 0;
36 self.cur_byte = 0;
37 }
38
39 Ok(())
40 }
41
42 pub fn write_8bit(&mut self, b: u8) -> Result<(), BitError> {
43 if self.cur_bit_num != 0 {
44 return Err(BitError {
45 value: BitErrorValue::CannotWrite8Bit,
46 });
47 }
48
49 self.writer.write_u8(b)?;
50 Ok(())
51 }
52
53 fn flush(&mut self) -> Result<(), BitError> {
54 if self.cur_bit_num == 8 {
55 self.writer.write_u8(self.cur_byte)?;
56 self.cur_bit_num = 0;
57 self.cur_byte = 0;
58 } else {
59 log::trace!("cannot flush: {}", self.cur_bit_num);
60 }
61
62 Ok(())
63 }
64
65 pub fn write_n_bits(&mut self, data: u64, bit_num: usize) -> Result<(), BitError> {
67 if bit_num > 64 {
68 return Err(BitError {
69 value: BitErrorValue::TooBig,
70 });
71 }
72 let mut bit_num_mut = bit_num;
73 let mut data_mut = data;
74
75 data_mut <<= 64 - bit_num;
77 self.cur_byte |= (data_mut >> (56 + self.cur_bit_num)) as u8;
78
79 let cur_byte_left_bit_num = 8 - self.cur_bit_num as usize;
80 if bit_num_mut >= cur_byte_left_bit_num {
81 data_mut <<= cur_byte_left_bit_num;
83 bit_num_mut -= cur_byte_left_bit_num;
84 self.cur_bit_num = 8;
85 self.flush()?;
86 } else {
87 self.cur_bit_num += bit_num_mut as u8;
89 return Ok(());
90 }
91
92 while bit_num_mut > 0 {
93 self.cur_byte = (data_mut >> 56) as u8;
94
95 if bit_num_mut > 8 {
96 self.cur_bit_num = 8;
97 self.flush()?;
98 data_mut <<= 8;
99 bit_num_mut -= 8;
100 } else {
101 self.cur_bit_num = bit_num_mut as u8;
102 break;
103 }
104 }
105
106 Ok(())
107 }
108
109 pub fn bits_aligment_8(&mut self) -> Result<(), BitError> {
110 self.cur_bit_num = 8;
111 self.flush()?;
112 Ok(())
113 }
114
115 pub fn get_current_bytes(&self) -> BytesMut {
116 self.writer.get_current_bytes()
117 }
118
119 pub fn len(&self) -> usize {
120 self.writer.len() * 8 + self.cur_bit_num as usize
121 }
122 pub fn is_empty(&self) -> bool {
123 self.len() == 0
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::BitsWriter;
    use super::BytesWriter;

    /// Builds a `BitsWriter` over a fresh, empty `BytesWriter`.
    fn new_writer() -> BitsWriter {
        BitsWriter::new(BytesWriter::new())
    }

    /// Writes each value in `bits` with `write_bit`, panicking on error.
    fn write_bits(writer: &mut BitsWriter, bits: &[u8]) {
        for &bit in bits {
            writer.write_bit(bit).unwrap();
        }
    }

    #[test]
    fn test_write_bit() {
        let mut bit_writer = new_writer();

        // Eight bits complete a byte, which is handed to the inner writer.
        write_bits(&mut bit_writer, &[0, 0, 0, 0, 0, 0, 1, 0]);
        assert_eq!(bit_writer.get_current_bytes().to_vec(), vec![0x02]);

        // Further bits are buffered MSB-first in the pending byte.
        write_bits(&mut bit_writer, &[1, 1]);
        assert_eq!(bit_writer.cur_bit_num, 2);
        assert_eq!(bit_writer.cur_byte, 0xC0);
    }

    #[test]
    fn test_write_n_bits() {
        let mut bit_writer = new_writer();

        write_bits(&mut bit_writer, &[1, 1, 0]);
        // 3 pending bits + 7 new bits: one full byte plus 2 pending bits.
        bit_writer.write_n_bits(0x03, 7).unwrap();

        assert_eq!(bit_writer.get_current_bytes().to_vec(), vec![0xC0]);
        assert_eq!(bit_writer.cur_bit_num, 2);
        assert_eq!(bit_writer.cur_byte, 0xC0);
    }

    #[test]
    fn test_write_n_bits_rejects_more_than_64_bits() {
        let mut bit_writer = new_writer();
        assert!(bit_writer.write_n_bits(0, 65).is_err());
    }

    #[test]
    fn test_write_8bit_requires_alignment() {
        let mut bit_writer = new_writer();

        bit_writer.write_8bit(0xAB).unwrap();
        bit_writer.write_bit(1).unwrap();
        assert!(bit_writer.write_8bit(0xCD).is_err());
        assert_eq!(bit_writer.get_current_bytes().to_vec(), vec![0xAB]);
    }

    #[test]
    fn test_len_counts_bits() {
        let mut bit_writer = new_writer();
        assert!(bit_writer.is_empty());

        bit_writer.write_n_bits(0x5, 3).unwrap();
        assert_eq!(bit_writer.len(), 3);

        bit_writer.write_n_bits(0x1F, 5).unwrap();
        assert_eq!(bit_writer.len(), 8);
    }

    #[test]
    fn test_bits_aligment_8() {
        let mut bit_writer = new_writer();

        write_bits(&mut bit_writer, &[1, 1, 0]);
        // Padding the 3 pending bits with zeros emits 0b1100_0000.
        bit_writer.bits_aligment_8().unwrap();
        assert_eq!(bit_writer.get_current_bytes().to_vec(), vec![0xC0]);

        // The writer restarts from a clean, aligned byte.
        write_bits(&mut bit_writer, &[1, 1, 0]);
        assert_eq!(bit_writer.cur_bit_num, 3);
        assert_eq!(bit_writer.cur_byte, 0xC0);
    }
}