hayro_syntax/
bit_reader.rs

//! Bit-level I/O helpers: a reader that extracts values of up to 32 bits (most significant bit
//! first) from a byte stream, plus a matching writer and helpers for reading fixed-size chunks.
3
4use log::warn;
5use smallvec::{SmallVec, smallvec};
6use std::fmt::Debug;
7
/// The width, in bits, of a single value in a bit stream.
///
/// Values created through [`BitSize::from_u8`] are always at most 32 bits wide.
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
pub struct BitSize(u8);

impl BitSize {
    /// Create a new `BitSize`. Returns `None` if the number is bigger than 32.
    pub fn from_u8(value: u8) -> Option<Self> {
        if value > 32 { None } else { Some(Self(value)) }
    }

    /// Return the number of bits.
    pub fn bits(&self) -> usize {
        self.0 as usize
    }

    /// Return a mask with the lowest [`bits`](Self::bits) bits set (`0` for a width of 0).
    pub fn mask(&self) -> u32 {
        // Shift in 64-bit arithmetic so that a width of 32 does not overflow the shift.
        ((1u64 << self.0 as u64) - 1) as u32
    }
}

/// A reader that extracts unsigned values of up to 32 bits from a byte slice, most significant
/// bit first.
pub struct BitReader<'a> {
    data: &'a [u8],
    /// Current read position, in bits from the start of `data`.
    cur_pos: usize,
}

impl<'a> BitReader<'a> {
    /// Create a new bit reader.
    pub fn new(data: &'a [u8]) -> Self {
        Self::new_with(data, 0)
    }

    /// Create a new bit reader, and start at a specific bit offset.
    pub fn new_with(data: &'a [u8], cur_pos: usize) -> Self {
        Self { data, cur_pos }
    }

    /// Align the reader to the next byte boundary (no-op if it is already aligned).
    pub fn align(&mut self) {
        let bit_pos = self.bit_pos();

        if bit_pos != 0 {
            self.cur_pos += 8 - bit_pos;
        }
    }

    /// Read the given number of bits from the byte stream, most significant bit first.
    ///
    /// Returns `None` without advancing if the reader is at or past the end of the data (even
    /// for a zero-width read), or if fewer than `bit_size` bits remain. A zero-width read at a
    /// valid position yields `Some(0)` and does not advance.
    pub fn read(&mut self, bit_size: BitSize) -> Option<u32> {
        let byte_pos = self.byte_pos();
        let bit_pos = self.bit_pos();
        let bits = bit_size.bits();

        // `bits > 32` can only happen for a `BitSize` built without `from_u8`.
        if bits > 32 || byte_pos >= self.data.len() {
            return None;
        }

        let item = match bits {
            // Nothing to read; handled separately so `bits - 1` below cannot underflow.
            0 => 0,
            // Fast path: a whole byte, but only when the reader sits on a byte boundary.
            8 if bit_pos == 0 => u32::from(self.data[byte_pos]),
            _ => {
                // Offset (from `byte_pos`) of the last byte containing part of the value. With
                // `bit_pos <= 7` and `bits <= 32` this is at most 4, so the value always fits
                // into the 8-byte window below.
                let last = (bit_pos + bits - 1) / 8;
                let src = self.data.get(byte_pos..=byte_pos + last)?;

                let mut window = [0u8; 8];
                window[..src.len()].copy_from_slice(src);

                // Right-align the requested bits, then drop everything above them.
                let shifted = u64::from_be_bytes(window) >> (64 - bit_pos - bits);
                (shifted as u32) & bit_size.mask()
            }
        };

        self.cur_pos += bits;
        Some(item)
    }

    /// Index of the byte containing the current position.
    fn byte_pos(&self) -> usize {
        self.cur_pos / 8
    }

    /// Offset of the current position within its byte (0 = most significant bit).
    fn bit_pos(&self) -> usize {
        self.cur_pos % 8
    }
}
100
101#[derive(Debug)]
102pub(crate) struct BitWriter<'a> {
103    data: &'a mut [u8],
104    cur_pos: usize,
105    bit_size: BitSize,
106}
107
108impl<'a> BitWriter<'a> {
109    pub(crate) fn new(data: &'a mut [u8], bit_size: BitSize) -> Option<Self> {
110        if !matches!(bit_size.0, 1 | 2 | 4 | 8 | 16) {
111            return None;
112        }
113
114        Some(Self {
115            data,
116            bit_size,
117            cur_pos: 0,
118        })
119    }
120
121    pub(crate) fn split_off(self) -> (&'a [u8], BitWriter<'a>) {
122        // Assumes that we are currently aligned to a byte boundary!
123        let (left, right) = self.data.split_at_mut(self.cur_pos / 8);
124        (
125            left,
126            BitWriter {
127                data: right,
128                cur_pos: 0,
129                bit_size: self.bit_size,
130            },
131        )
132    }
133
134    /// Align the writer to the next byte boundary.
135    #[cfg(feature = "jpeg2000")]
136    pub(crate) fn align(&mut self) {
137        let bit_pos = self.bit_pos();
138
139        if bit_pos % 8 != 0 {
140            self.cur_pos += 8 - bit_pos;
141        }
142    }
143
144    pub(crate) fn cur_pos(&self) -> usize {
145        self.cur_pos
146    }
147
148    pub(crate) fn get_data(&self) -> &[u8] {
149        self.data
150    }
151
152    fn byte_pos(&self) -> usize {
153        self.cur_pos / 8
154    }
155
156    fn bit_pos(&self) -> usize {
157        self.cur_pos % 8
158    }
159
160    pub(crate) fn write(&mut self, val: u16) -> Option<()> {
161        let byte_pos = self.byte_pos();
162        let bit_size = self.bit_size;
163
164        match bit_size.0 {
165            1 | 2 | 4 => {
166                let bit_pos = self.bit_pos();
167
168                let base = self.data.get(byte_pos)?;
169                let shift = 8 - self.bit_size.bits() - bit_pos;
170                let item = ((val & self.bit_size.mask() as u16) as u8) << shift;
171
172                *(self.data.get_mut(byte_pos)?) = *base | item;
173                self.cur_pos += bit_size.bits();
174            }
175            8 => {
176                *(self.data.get_mut(byte_pos)?) = val as u8;
177                self.cur_pos += 8;
178            }
179            16 => {
180                self.data
181                    .get_mut(byte_pos..(byte_pos + 2))?
182                    .copy_from_slice(&val.to_be_bytes());
183                self.cur_pos += 16;
184            }
185            _ => unreachable!(),
186        }
187
188        Some(())
189    }
190}
191
192pub(crate) struct BitChunks<'a> {
193    reader: BitReader<'a>,
194    bit_size: BitSize,
195    chunk_len: usize,
196}
197
198impl<'a> BitChunks<'a> {
199    pub(crate) fn new(data: &'a [u8], bit_size: BitSize, chunk_len: usize) -> Option<Self> {
200        if bit_size.0 > 16 {
201            warn!("BitChunks doesn't support working with bit sizes > 16.");
202
203            return None;
204        }
205
206        let reader = BitReader::new(data);
207
208        Some(Self {
209            reader,
210            bit_size,
211            chunk_len,
212        })
213    }
214}
215
216impl Iterator for BitChunks<'_> {
217    type Item = BitChunk;
218
219    fn next(&mut self) -> Option<Self::Item> {
220        let mut bits = SmallVec::new();
221
222        for _ in 0..self.chunk_len {
223            bits.push(self.reader.read(self.bit_size)? as u16);
224        }
225
226        Some(BitChunk { bits })
227    }
228}
229
230#[derive(Debug, Clone)]
231pub(crate) struct BitChunk {
232    bits: SmallVec<[u16; 4]>,
233}
234
235impl BitChunk {
236    pub(crate) fn iter(&self) -> impl Iterator<Item = u16> + '_ {
237        self.bits.iter().copied()
238    }
239
240    pub(crate) fn new(val: u8, count: usize) -> Self {
241        Self {
242            bits: smallvec![val as u16; count],
243        }
244    }
245
246    pub(crate) fn from_reader(
247        bit_reader: &mut BitReader,
248        bit_size: BitSize,
249        chunk_len: usize,
250    ) -> Option<Self> {
251        if bit_size.0 > 16 {
252            warn!("BitChunk doesn't support working with bit sizes > 16.");
253
254            return None;
255        }
256
257        let mut bits = SmallVec::new();
258
259        for _ in 0..chunk_len {
260            bits.push(bit_reader.read(bit_size)? as u16);
261        }
262
263        Some(BitChunk { bits })
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    const BS1: BitSize = BitSize(1);
    const BS2: BitSize = BitSize(2);
    const BS4: BitSize = BitSize(4);
    const BS8: BitSize = BitSize(8);
    const BS16: BitSize = BitSize(16);

    /// Five bytes of mixed bits shared by the multi-width reader/writer tests.
    const MIXED: [u8; 5] = [0b10011000, 0b00011111, 0b10101001, 0b11101001, 0b00011010];
    /// Two mostly-zero bytes shared by the 1-bit/2-bit, alignment and chunking tests.
    const SPARSE: [u8; 2] = [0b10011000, 0b00010000];

    /// Performs one `read(size)` per entry of `expected`, checking each result in order.
    fn expect_reads(reader: &mut BitReader<'_>, size: BitSize, expected: &[u32]) {
        for (idx, &want) in expected.iter().enumerate() {
            assert_eq!(reader.read(size), Some(want), "read #{idx} with {size:?}");
        }
    }

    /// Writes `values` through a `bits`-wide writer into a zeroed buffer of `len` bytes and
    /// returns the resulting buffer.
    fn write_into_zeroed(len: usize, bits: u8, values: &[u16]) -> Vec<u8> {
        let mut buf = vec![0u8; len];
        let mut writer = BitWriter::new(&mut buf, BitSize::from_u8(bits).unwrap()).unwrap();
        for &value in values {
            writer.write(value).unwrap();
        }
        buf
    }

    #[test]
    fn bit_reader_16() {
        let data = [0x01u8, 0x02, 0x03, 0x04, 0x05, 0x06];
        let mut reader = BitReader::new(&data);
        for pair in data.chunks_exact(2) {
            let want = u16::from_be_bytes([pair[0], pair[1]]);
            assert_eq!(reader.read(BS16).unwrap() as u16, want);
        }
    }

    #[test]
    fn bit_writer_16() {
        let values = [[0x01, 0x02], [0x03, 0x04], [0x05, 0x06]].map(u16::from_be_bytes);
        assert_eq!(
            write_into_zeroed(6, 16, &values),
            [0x01, 0x02, 0x03, 0x04, 0x05, 0x06]
        );
    }

    #[test]
    fn bit_reader_12() {
        let data = MIXED;
        let mut reader = BitReader::new(&data);
        let twelve = BitSize::from_u8(12).unwrap();
        expect_reads(
            &mut reader,
            twelve,
            &[0b100110000001, 0b111110101001, 0b111010010001],
        );
    }

    #[test]
    fn bit_reader_9() {
        let data = MIXED;
        let mut reader = BitReader::new(&data);
        let nine = BitSize::from_u8(9).unwrap();
        expect_reads(
            &mut reader,
            nine,
            &[0b100110000, 0b001111110, 0b101001111, 0b010010001],
        );
    }

    #[test]
    fn bit_writer_8() {
        assert_eq!(write_into_zeroed(3, 8, &[0x01, 0x02, 0x03]), [0x01, 0x02, 0x03]);
    }

    #[test]
    fn bit_reader_8() {
        let data = [0x01u8, 0x02, 0x03];
        let mut reader = BitReader::new(&data);
        expect_reads(&mut reader, BS8, &[0x01, 0x02, 0x03]);
    }

    #[test]
    fn bit_writer_4() {
        let nibbles = [0b1001, 0b1000, 0b0001, 0b1111, 0b1010, 0b1001];
        assert_eq!(write_into_zeroed(3, 4, &nibbles), &MIXED[..3]);
    }

    #[test]
    fn bit_reader_4() {
        let data = MIXED;
        let mut reader = BitReader::new(&data[..3]);
        expect_reads(
            &mut reader,
            BS4,
            &[0b1001, 0b1000, 0b0001, 0b1111, 0b1010, 0b1001],
        );
    }

    #[test]
    fn bit_writer_2() {
        let pairs = [0b10, 0b01, 0b10, 0b00, 0b00, 0b01, 0b00, 0b00];
        assert_eq!(write_into_zeroed(2, 2, &pairs), SPARSE);
    }

    #[test]
    fn bit_reader_2() {
        let data = SPARSE;
        let mut reader = BitReader::new(&data);
        expect_reads(
            &mut reader,
            BS2,
            &[0b10, 0b01, 0b10, 0b00, 0b00, 0b01, 0b00, 0b00],
        );
    }

    #[test]
    fn bit_writer_1() {
        let bits = [1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0];
        assert_eq!(write_into_zeroed(2, 1, &bits), SPARSE);
    }

    #[test]
    fn bit_reader_1() {
        let data = SPARSE;
        let mut reader = BitReader::new(&data);
        expect_reads(
            &mut reader,
            BS1,
            &[1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
        );
    }

    #[test]
    fn bit_reader_align() {
        let data = SPARSE;
        let mut reader = BitReader::new(&data);
        expect_reads(&mut reader, BS1, &[1, 0, 0, 1]);
        reader.align();
        expect_reads(&mut reader, BS1, &[0, 0, 0, 1, 0, 0, 0, 0]);
    }

    #[test]
    fn bit_reader_chunks() {
        let data = SPARSE;
        let mut chunks = BitChunks::new(&data, BitSize::from_u8(1).unwrap(), 3).unwrap();
        for want in [[1u16, 0, 0], [1, 1, 0], [0, 0, 0], [0, 0, 1], [0, 0, 0]] {
            let got: Vec<u16> = chunks.next().unwrap().iter().collect();
            assert_eq!(got, want);
        }
    }

    #[test]
    fn bit_reader_varying_bit_sizes() {
        let data = MIXED;
        let mut reader = BitReader::new(&data[..3]);
        let steps: [(u8, u32); 7] = [
            (4, 0b1001),
            (1, 0b1),
            (4, 0b0000),
            (5, 0b00111),
            (1, 0b1),
            (2, 0b11),
            (7, 0b0101001),
        ];
        for (width, want) in steps {
            assert_eq!(
                reader.read(BitSize::from_u8(width).unwrap()),
                Some(want),
                "width {width}"
            );
        }
    }
}