1use std::convert::TryInto;
2
/// Reads most-significant-bit-first bit fields out of a byte slice.
///
/// Bit `0` is the most significant bit of `bytes[0]`, bit `8` is the most
/// significant bit of `bytes[1]`, and so on. Fixed-width reads and
/// [`BitReader::advance_by`] either complete entirely or fail without moving
/// the position.
#[derive(Debug, Clone)]
pub struct BitReader<'a> {
    bytes: &'a [u8],
    /// Index of the next bit to read, counted from the start of `bytes`.
    position: usize,
}

impl<'a> BitReader<'a> {
    /// Creates a reader positioned at the most significant bit of `bytes[0]`.
    pub fn new(bytes: &'a [u8]) -> BitReader<'a> {
        BitReader { bytes, position: 0 }
    }

    /// Reads `bit_count` (at most 8) bits as an unsigned integer, MSB first.
    ///
    /// Returns `None`, consuming nothing, if fewer than `bit_count` bits remain.
    pub fn read_u8(&mut self, bit_count: u8) -> Option<u8> {
        debug_assert!(bit_count <= 8);
        // Truncation only happens if the debug_assert above is violated in a
        // release build; it keeps the low 8 bits, exactly as accumulating into
        // a `u8` would.
        self.read_bits(bit_count).map(|value| value as u8)
    }

    /// Reads `bit_count` (at most 16) bits as an unsigned integer, MSB first.
    ///
    /// Returns `None`, consuming nothing, if fewer than `bit_count` bits remain.
    pub fn read_u16(&mut self, bit_count: u8) -> Option<u16> {
        debug_assert!(bit_count <= 16);
        // See `read_u8` for why the truncating cast is intentional.
        self.read_bits(bit_count).map(|value| value as u16)
    }

    /// Reads `bit_count` (at most 32) bits as an unsigned integer, MSB first.
    ///
    /// Returns `None`, consuming nothing, if fewer than `bit_count` bits remain.
    pub fn read_u32(&mut self, bit_count: u8) -> Option<u32> {
        debug_assert!(bit_count <= 32);
        // See `read_u8` for why the truncating cast is intentional.
        self.read_bits(bit_count).map(|value| value as u32)
    }

    /// Reads `bit_count` (at most 64) bits as an unsigned integer, MSB first.
    ///
    /// Returns `None`, consuming nothing, if fewer than `bit_count` bits remain.
    pub fn read_u64(&mut self, bit_count: u8) -> Option<u64> {
        debug_assert!(bit_count <= 64);
        self.read_bits(bit_count)
    }

    /// Reads a single bit; `None` once every bit has been consumed.
    pub fn read_bool(&mut self) -> Option<bool> {
        self.next()
    }

    /// Skips `bit_count` bits.
    ///
    /// Returns `false` and leaves the position unchanged if that would move
    /// past the end of the data, including when `position + bit_count`
    /// would overflow `usize`.
    pub fn advance_by(&mut self, bit_count: usize) -> bool {
        match self.span_end(bit_count) {
            Some(end) => {
                self.position = end;
                true
            }
            None => false,
        }
    }

    /// Index of the next bit to be read.
    pub fn position(&self) -> usize {
        self.position
    }

    /// Shared body of the `read_uN` methods: consumes `bit_count` bits and
    /// packs them MSB first into the low bits of a `u64`.
    ///
    /// Bits beyond the 64 most recent are shifted out, so oversized counts
    /// keep only the trailing 64 bits.
    fn read_bits(&mut self, bit_count: u8) -> Option<u64> {
        let count = usize::from(bit_count);
        // Bail out before consuming anything if the whole field is not there.
        self.span_end(count)?;
        let value = self
            .by_ref()
            .take(count)
            .fold(0u64, |acc, bit| (acc << 1) | u64::from(bit));
        Some(value)
    }

    /// Position just past the next `bit_count` bits, or `None` when that span
    /// does not lie entirely within `bytes` (or its end overflows `usize`).
    fn span_end(&self, bit_count: usize) -> Option<usize> {
        let end = self.position.checked_add(bit_count)?;
        Some(end).filter(|&end| end <= self.bit_len())
    }

    /// Total number of bits in `bytes`; saturates instead of overflowing for
    /// slices too large for their bit count to fit in a `usize`.
    fn bit_len(&self) -> usize {
        self.bytes.len().saturating_mul(8)
    }

    /// Number of bits that have not been read yet (0 if the position has been
    /// pushed past the end).
    fn remaining_bits(&self) -> usize {
        self.bit_len().saturating_sub(self.position)
    }
}

impl<'a> Iterator for BitReader<'a> {
    type Item = bool;

    /// Yields the next bit (MSB first within each byte), or `None` once the
    /// end of the data has been reached.
    fn next(&mut self) -> Option<Self::Item> {
        let position = self.position;
        let byte = *self.bytes.get(position / 8)?;

        self.position += 1;

        // Move the wanted bit into the MSB slot, then test that slot.
        let bit = byte << (position % 8);
        Some(bit & 0b1000_0000 != 0)
    }

    /// Exact: the reader yields one item per remaining bit.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.remaining_bits();
        (remaining, Some(remaining))
    }
}
102
103#[cfg(target_pointer_width = "64")]
116pub struct CachedBitReader {
117 cache: u64,
118 read: usize,
119}
120
121#[cfg(target_pointer_width = "64")]
122impl CachedBitReader {
123 pub fn new(reader: &BitReader<'_>) -> Option<Self> {
124 let mut this = Self { cache: 0, read: 0 };
125 this.refresh(reader)?;
126 Some(this)
127 }
128
129 pub fn refresh(&mut self, reader: &BitReader<'_>) -> Option<()> {
130 let pos = reader.position / 8;
131 let data = reader.bytes.get(pos..pos + 8)?;
132
133 self.cache = u64::from_be_bytes(data.try_into().unwrap());
134 self.cache <<= reader.position % 8;
135 self.read = 0;
136 Some(())
137 }
138
139 pub fn read(&self) -> usize {
140 self.read
141 }
142
143 pub fn overflowed(&self) -> bool {
144 self.read() > 57
145 }
146
147 pub fn restore(&mut self, reader: &mut BitReader<'_>, read: usize) {
148 reader.position += read;
149 }
150}
151
152#[cfg(target_pointer_width = "64")]
153impl Iterator for CachedBitReader {
154 type Item = bool;
155
156 fn next(&mut self) -> Option<Self::Item> {
157 let bit = self.cache & !(u64::max_value() >> 1);
159
160 self.read += 1;
161 self.cache <<= 1;
162
163 Some(bit != 0)
164 }
165}