hayro_syntax/
bit_reader.rs

//! A bit reader that supports reading numbers of up to 32 bits each from a bit
//! stream, along with helpers for writing bit-packed values and for reading them
//! in fixed-size chunks.
3
4use log::warn;
5use smallvec::{SmallVec, smallvec};
6use std::fmt::Debug;
7
/// A bit width in the range `0..=32`.
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
pub struct BitSize(u8);

impl BitSize {
    /// Create a new `BitSize` from a raw width.
    ///
    /// Returns `None` if `value` is greater than 32.
    pub fn from_u8(value: u8) -> Option<Self> {
        (value <= 32).then_some(Self(value))
    }

    /// The width, in bits.
    pub fn bits(&self) -> usize {
        usize::from(self.0)
    }

    /// A mask with the lowest `bits()` bits set (zero for a width of 0).
    pub fn mask(&self) -> u32 {
        // Built in 64 bits so that a width of 32 does not overflow the shift.
        let all_ones = (1u64 << u32::from(self.0)) - 1;
        all_ones as u32
    }
}
28
29/// A bit reader.
30pub struct BitReader<'a> {
31    data: &'a [u8],
32    cur_pos: usize,
33}
34
35impl<'a> BitReader<'a> {
36    /// Create a new bit reader.
37    pub fn new(data: &'a [u8]) -> Self {
38        Self::new_with(data, 0)
39    }
40
41    /// Create a new bit reader, and start at a specific bit offset.
42    pub fn new_with(data: &'a [u8], cur_pos: usize) -> Self {
43        Self { data, cur_pos }
44    }
45
46    /// Align the reader to the next byte boundary.
47    #[inline]
48    pub fn align(&mut self) {
49        let bit_pos = self.bit_pos();
50
51        if !bit_pos.is_multiple_of(8) {
52            self.cur_pos += 8 - bit_pos;
53        }
54    }
55
56    /// Read the given number of bits from the byte stream.
57    #[inline(always)]
58    pub fn read(&mut self, bit_size: BitSize) -> Option<u32> {
59        let byte_pos = self.byte_pos();
60
61        if bit_size.0 > 32 || byte_pos >= self.data.len() {
62            return None;
63        }
64
65        let item = match bit_size.0 {
66            8 => {
67                let item = self.data[byte_pos] as u32;
68                self.cur_pos += 8;
69
70                Some(item)
71            }
72            0..=32 => {
73                let bit_pos = self.bit_pos();
74                let end_byte_pos = (bit_pos + bit_size.0 as usize - 1) / 8;
75                let mut read = [0u8; 8];
76
77                for (i, r) in read.iter_mut().enumerate().take(end_byte_pos + 1) {
78                    *r = *self.data.get(byte_pos + i)?;
79                }
80
81                let item = (u64::from_be_bytes(read) >> (64 - bit_pos - bit_size.0 as usize))
82                    as u32
83                    & bit_size.mask();
84                self.cur_pos += bit_size.0 as usize;
85
86                Some(item)
87            }
88            _ => unreachable!(),
89        }?;
90
91        Some(item)
92    }
93
94    #[inline]
95    fn byte_pos(&self) -> usize {
96        self.cur_pos / 8
97    }
98
99    #[inline]
100    fn bit_pos(&self) -> usize {
101        self.cur_pos % 8
102    }
103}
104
105#[derive(Debug)]
106pub(crate) struct BitWriter<'a> {
107    data: &'a mut [u8],
108    cur_pos: usize,
109    bit_size: BitSize,
110}
111
112impl<'a> BitWriter<'a> {
113    pub(crate) fn new(data: &'a mut [u8], bit_size: BitSize) -> Option<Self> {
114        if !matches!(bit_size.0, 1 | 2 | 4 | 8 | 16) {
115            return None;
116        }
117
118        Some(Self {
119            data,
120            bit_size,
121            cur_pos: 0,
122        })
123    }
124
125    pub(crate) fn split_off(self) -> (&'a [u8], BitWriter<'a>) {
126        // Assumes that we are currently aligned to a byte boundary!
127        let (left, right) = self.data.split_at_mut(self.cur_pos / 8);
128        (
129            left,
130            BitWriter {
131                data: right,
132                cur_pos: 0,
133                bit_size: self.bit_size,
134            },
135        )
136    }
137
138    /// Align the writer to the next byte boundary.
139    #[cfg(feature = "jpeg2000")]
140    pub(crate) fn align(&mut self) {
141        let bit_pos = self.bit_pos();
142
143        if !bit_pos.is_multiple_of(8) {
144            self.cur_pos += 8 - bit_pos;
145        }
146    }
147
148    pub(crate) fn cur_pos(&self) -> usize {
149        self.cur_pos
150    }
151
152    pub(crate) fn get_data(&self) -> &[u8] {
153        self.data
154    }
155
156    fn byte_pos(&self) -> usize {
157        self.cur_pos / 8
158    }
159
160    fn bit_pos(&self) -> usize {
161        self.cur_pos % 8
162    }
163
164    pub(crate) fn write(&mut self, val: u16) -> Option<()> {
165        let byte_pos = self.byte_pos();
166        let bit_size = self.bit_size;
167
168        match bit_size.0 {
169            1 | 2 | 4 => {
170                let bit_pos = self.bit_pos();
171
172                let base = self.data.get(byte_pos)?;
173                let shift = 8 - self.bit_size.bits() - bit_pos;
174                let item = ((val & self.bit_size.mask() as u16) as u8) << shift;
175
176                *(self.data.get_mut(byte_pos)?) = *base | item;
177                self.cur_pos += bit_size.bits();
178            }
179            8 => {
180                *(self.data.get_mut(byte_pos)?) = val as u8;
181                self.cur_pos += 8;
182            }
183            16 => {
184                self.data
185                    .get_mut(byte_pos..(byte_pos + 2))?
186                    .copy_from_slice(&val.to_be_bytes());
187                self.cur_pos += 16;
188            }
189            _ => unreachable!(),
190        }
191
192        Some(())
193    }
194}
195
196pub(crate) struct BitChunks<'a> {
197    reader: BitReader<'a>,
198    bit_size: BitSize,
199    chunk_len: usize,
200}
201
202impl<'a> BitChunks<'a> {
203    pub(crate) fn new(data: &'a [u8], bit_size: BitSize, chunk_len: usize) -> Option<Self> {
204        if bit_size.0 > 16 {
205            warn!("BitChunks doesn't support working with bit sizes > 16.");
206
207            return None;
208        }
209
210        let reader = BitReader::new(data);
211
212        Some(Self {
213            reader,
214            bit_size,
215            chunk_len,
216        })
217    }
218}
219
220impl Iterator for BitChunks<'_> {
221    type Item = BitChunk;
222
223    fn next(&mut self) -> Option<Self::Item> {
224        let mut bits = SmallVec::new();
225
226        for _ in 0..self.chunk_len {
227            bits.push(self.reader.read(self.bit_size)? as u16);
228        }
229
230        Some(BitChunk { bits })
231    }
232}
233
234#[derive(Debug, Clone)]
235pub(crate) struct BitChunk {
236    bits: SmallVec<[u16; 4]>,
237}
238
239impl BitChunk {
240    pub(crate) fn iter(&self) -> impl Iterator<Item = u16> + '_ {
241        self.bits.iter().copied()
242    }
243
244    pub(crate) fn new(val: u8, count: usize) -> Self {
245        Self {
246            bits: smallvec![val as u16; count],
247        }
248    }
249
250    pub(crate) fn from_reader(
251        bit_reader: &mut BitReader,
252        bit_size: BitSize,
253        chunk_len: usize,
254    ) -> Option<Self> {
255        if bit_size.0 > 16 {
256            warn!("BitChunk doesn't support working with bit sizes > 16.");
257
258            return None;
259        }
260
261        let mut bits = SmallVec::new();
262
263        for _ in 0..chunk_len {
264            bits.push(bit_reader.read(bit_size)? as u16);
265        }
266
267        Some(BitChunk { bits })
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    const BS1: BitSize = BitSize(1);
    const BS2: BitSize = BitSize(2);
    const BS4: BitSize = BitSize(4);
    const BS8: BitSize = BitSize(8);
    const BS16: BitSize = BitSize(16);

    /// Shared fixtures: the same bytes are decoded at several widths below.
    const TWO_BYTES: [u8; 2] = [0b10011000, 0b00010000];
    const THREE_BYTES: [u8; 3] = [0b10011000, 0b00011111, 0b10101001];
    const FIVE_BYTES: [u8; 5] = [0b10011000, 0b00011111, 0b10101001, 0b11101001, 0b00011010];

    /// `TWO_BYTES` as 1-bit values.
    const ONE_BIT_VALUES: [u16; 16] = [1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0];
    /// `TWO_BYTES` as 2-bit values.
    const TWO_BIT_VALUES: [u16; 8] = [0b10, 0b01, 0b10, 0b00, 0b00, 0b01, 0b00, 0b00];
    /// `THREE_BYTES` as 4-bit values.
    const FOUR_BIT_VALUES: [u16; 6] = [0b1001, 0b1000, 0b0001, 0b1111, 0b1010, 0b1001];

    /// Build a validated `BitSize`, panicking on widths above 32.
    fn bs(bits: u8) -> BitSize {
        BitSize::from_u8(bits).unwrap()
    }

    /// Read one value of width `size` per entry of `expected` and compare in order.
    fn expect_reads(reader: &mut BitReader, size: BitSize, expected: &[u16]) {
        for (idx, &want) in expected.iter().enumerate() {
            assert_eq!(reader.read(size).unwrap(), u32::from(want), "read #{idx}");
        }
    }

    /// Write `values` of width `size` into a zeroed buffer of `len` bytes and return it.
    fn write_values(len: usize, size: BitSize, values: &[u16]) -> Vec<u8> {
        let mut buf = vec![0u8; len];
        let mut writer = BitWriter::new(&mut buf, size).unwrap();
        for &value in values {
            writer.write(value).unwrap();
        }
        buf
    }

    #[test]
    fn bit_reader_16() {
        let data = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06];
        expect_reads(&mut BitReader::new(&data), BS16, &[0x0102, 0x0304, 0x0506]);
    }

    #[test]
    fn bit_writer_16() {
        let buf = write_values(6, bs(16), &[0x0102, 0x0304, 0x0506]);
        assert_eq!(buf, [0x01, 0x02, 0x03, 0x04, 0x05, 0x06]);
    }

    #[test]
    fn bit_reader_12() {
        expect_reads(
            &mut BitReader::new(&FIVE_BYTES),
            bs(12),
            &[0b100110000001, 0b111110101001, 0b111010010001],
        );
    }

    #[test]
    fn bit_reader_9() {
        expect_reads(
            &mut BitReader::new(&FIVE_BYTES),
            bs(9),
            &[0b100110000, 0b001111110, 0b101001111, 0b010010001],
        );
    }

    #[test]
    fn bit_writer_8() {
        assert_eq!(write_values(3, bs(8), &[0x01, 0x02, 0x03]), [0x01, 0x02, 0x03]);
    }

    #[test]
    fn bit_reader_8() {
        let data = [0x01, 0x02, 0x03];
        expect_reads(&mut BitReader::new(&data), BS8, &[0x01, 0x02, 0x03]);
    }

    #[test]
    fn bit_writer_4() {
        assert_eq!(write_values(3, bs(4), &FOUR_BIT_VALUES), THREE_BYTES);
    }

    #[test]
    fn bit_reader_4() {
        expect_reads(&mut BitReader::new(&THREE_BYTES), BS4, &FOUR_BIT_VALUES);
    }

    #[test]
    fn bit_writer_2() {
        assert_eq!(write_values(2, bs(2), &TWO_BIT_VALUES), TWO_BYTES);
    }

    #[test]
    fn bit_reader_2() {
        expect_reads(&mut BitReader::new(&TWO_BYTES), BS2, &TWO_BIT_VALUES);
    }

    #[test]
    fn bit_writer_1() {
        assert_eq!(write_values(2, bs(1), &ONE_BIT_VALUES), TWO_BYTES);
    }

    #[test]
    fn bit_reader_1() {
        expect_reads(&mut BitReader::new(&TWO_BYTES), BS1, &ONE_BIT_VALUES);
    }

    #[test]
    fn bit_reader_align() {
        let mut reader = BitReader::new(&TWO_BYTES);
        // Read half of the first byte, then skip the rest of it.
        expect_reads(&mut reader, BS1, &ONE_BIT_VALUES[..4]);
        reader.align();
        expect_reads(&mut reader, BS1, &ONE_BIT_VALUES[8..]);
    }

    #[test]
    fn bit_reader_chunks() {
        let mut chunks = BitChunks::new(&TWO_BYTES, bs(1), 3).unwrap();
        for want in ONE_BIT_VALUES.chunks_exact(3) {
            let got: Vec<u16> = chunks.next().unwrap().iter().collect();
            assert_eq!(got, want);
        }
    }

    #[test]
    fn bit_reader_varying_bit_sizes() {
        let steps = [
            (BS4, 0b1001),
            (BS1, 0b1),
            (BS4, 0b0000),
            (bs(5), 0b00111),
            (BS1, 0b1),
            (BS2, 0b11),
            (bs(7), 0b0101001),
        ];
        let mut reader = BitReader::new(&THREE_BYTES);
        for &(size, want) in &steps {
            assert_eq!(reader.read(size).unwrap(), want);
        }
    }
}
518        );
519    }
520}