dsi_bitstream/impls/
bit_reader.rs1use core::convert::Infallible;
10#[cfg(feature = "mem_dbg")]
11use mem_dbg::{MemDbg, MemSize};
12use std::error::Error;
13
14use crate::codes::params::{DefaultReadParams, ReadParams};
15use crate::traits::*;
16
17#[derive(Debug, Clone)]
37#[cfg_attr(feature = "mem_dbg", derive(MemDbg, MemSize))]
38pub struct BitReader<E: Endianness, WR, RP: ReadParams = DefaultReadParams> {
39 data: WR,
41 bit_index: u64,
43 _marker: core::marker::PhantomData<(E, RP)>,
44}
45
46impl<E: Endianness, WR, RP: ReadParams> BitReader<E, WR, RP> {
47 pub fn new(data: WR) -> Self {
48 check_tables(32);
49 Self {
50 data,
51 bit_index: 0,
52 _marker: core::marker::PhantomData,
53 }
54 }
55}
56
57impl<
58 E: Error + Send + Sync + 'static,
59 WR: WordRead<Error = E, Word = u64> + WordSeek<Error = E>,
60 RP: ReadParams,
61 > BitRead<BE> for BitReader<BE, WR, RP>
62{
63 type Error = <WR as WordRead>::Error;
64 type PeekWord = u32;
65
66 #[inline]
67 fn skip_bits(&mut self, n_bits: usize) -> Result<(), Self::Error> {
68 self.bit_index += n_bits as u64;
69 Ok(())
70 }
71
72 #[inline]
73 fn read_bits(&mut self, n_bits: usize) -> Result<u64, Self::Error> {
74 if n_bits == 0 {
75 return Ok(0);
76 }
77
78 assert!(n_bits <= 64);
79
80 self.data.set_word_pos(self.bit_index / 64)?;
81 let in_word_offset = (self.bit_index % 64) as usize;
82
83 let res = if (in_word_offset + n_bits) <= 64 {
84 let word = self.data.read_word()?.to_be();
86 (word << in_word_offset) >> (64 - n_bits)
87 } else {
88 let high_word = self.data.read_word()?.to_be();
90 let low_word = self.data.read_word()?.to_be();
91 let shamt1 = 64 - n_bits;
92 let shamt2 = 128 - in_word_offset - n_bits;
93 ((high_word << in_word_offset) >> shamt1) | (low_word >> shamt2)
94 };
95 self.bit_index += n_bits as u64;
96 Ok(res)
97 }
98
99 #[inline]
100 fn peek_bits(&mut self, n_bits: usize) -> Result<u32, Self::Error> {
101 if n_bits == 0 {
102 return Ok(0);
103 }
104
105 assert!(n_bits <= 32);
106
107 self.data.set_word_pos(self.bit_index / 64)?;
108 let in_word_offset = (self.bit_index % 64) as usize;
109
110 let res = if (in_word_offset + n_bits) <= 64 {
111 let word = self.data.read_word()?.to_be();
113 (word << in_word_offset) >> (64 - n_bits)
114 } else {
115 let high_word = self.data.read_word()?.to_be();
117 let low_word = self.data.read_word()?.to_be();
118 let shamt1 = 64 - n_bits;
119 let shamt2 = 128 - in_word_offset - n_bits;
120 ((high_word << in_word_offset) >> shamt1) | (low_word >> shamt2)
121 };
122 Ok(res as u32)
123 }
124
125 #[inline]
126 fn read_unary(&mut self) -> Result<u64, Self::Error> {
127 self.data.set_word_pos(self.bit_index / 64)?;
128 let in_word_offset = self.bit_index % 64;
129 let mut bits_in_word = 64 - in_word_offset;
130 let mut total = 0;
131
132 let mut word = self.data.read_word()?.to_be();
133 word <<= in_word_offset;
134 loop {
135 let zeros = word.leading_zeros() as u64;
136 if zeros < bits_in_word {
138 self.bit_index += total + zeros + 1;
139 return Ok(total + zeros);
140 }
141 total += bits_in_word;
142 bits_in_word = 64;
143 word = self.data.read_word()?.to_be();
144 }
145 }
146
147 #[inline(always)]
148 fn skip_bits_after_peek(&mut self, n: usize) {
149 self.bit_index += n as u64;
150 }
151}
152
153impl<WR: WordSeek, RP: ReadParams> BitSeek for BitReader<LE, WR, RP> {
154 type Error = Infallible;
155
156 fn bit_pos(&mut self) -> Result<u64, Self::Error> {
157 Ok(self.bit_index)
158 }
159
160 fn set_bit_pos(&mut self, bit_index: u64) -> Result<(), Self::Error> {
161 self.bit_index = bit_index;
162 Ok(())
163 }
164}
165
166impl<WR: WordSeek, RP: ReadParams> BitSeek for BitReader<BE, WR, RP> {
167 type Error = Infallible;
168
169 fn bit_pos(&mut self) -> Result<u64, Self::Error> {
170 Ok(self.bit_index)
171 }
172
173 fn set_bit_pos(&mut self, bit_index: u64) -> Result<(), Self::Error> {
174 self.bit_index = bit_index;
175 Ok(())
176 }
177}
178
179impl<
180 E: Error + Send + Sync + 'static,
181 WR: WordRead<Error = E, Word = u64> + WordSeek<Error = E>,
182 RP: ReadParams,
183 > BitRead<LE> for BitReader<LE, WR, RP>
184{
185 type Error = <WR as WordRead>::Error;
186 type PeekWord = u32;
187
188 #[inline]
189 fn skip_bits(&mut self, n_bits: usize) -> Result<(), Self::Error> {
190 self.bit_index += n_bits as u64;
191 Ok(())
192 }
193
194 #[inline]
195 fn read_bits(&mut self, n_bits: usize) -> Result<u64, Self::Error> {
196 #[cfg(feature = "checks")]
197 assert!(n_bits <= 64);
198
199 if n_bits == 0 {
200 return Ok(0);
201 }
202
203 self.data.set_word_pos(self.bit_index / 64)?;
204 let in_word_offset = (self.bit_index % 64) as usize;
205
206 let res = if (in_word_offset + n_bits) <= 64 {
207 let word = self.data.read_word()?.to_le();
209 let shamt = 64 - n_bits;
210 (word << (shamt - in_word_offset)) >> shamt
211 } else {
212 let low_word = self.data.read_word()?.to_le();
214 let high_word = self.data.read_word()?.to_le();
215 let shamt1 = 128 - in_word_offset - n_bits;
216 let shamt2 = 64 - n_bits;
217 ((high_word << shamt1) >> shamt2) | (low_word >> in_word_offset)
218 };
219 self.bit_index += n_bits as u64;
220 Ok(res)
221 }
222
223 #[inline]
224 fn peek_bits(&mut self, n_bits: usize) -> Result<u32, Self::Error> {
225 if n_bits == 0 {
226 return Ok(0);
227 }
228
229 assert!(n_bits <= 32);
230
231 self.data.set_word_pos(self.bit_index / 64)?;
232 let in_word_offset = (self.bit_index % 64) as usize;
233
234 let res = if (in_word_offset + n_bits) <= 64 {
235 let word = self.data.read_word()?.to_le();
237 let shamt = 64 - n_bits;
238 (word << (shamt - in_word_offset)) >> shamt
239 } else {
240 let low_word = self.data.read_word()?.to_le();
242 let high_word = self.data.read_word()?.to_le();
243 let shamt1 = 128 - in_word_offset - n_bits;
244 let shamt2 = 64 - n_bits;
245 ((high_word << shamt1) >> shamt2) | (low_word >> in_word_offset)
246 };
247 Ok(res as u32)
248 }
249
250 #[inline]
251 fn read_unary(&mut self) -> Result<u64, Self::Error> {
252 self.data.set_word_pos(self.bit_index / 64)?;
253 let in_word_offset = self.bit_index % 64;
254 let mut bits_in_word = 64 - in_word_offset;
255 let mut total = 0;
256
257 let mut word = self.data.read_word()?.to_le();
258 word >>= in_word_offset;
259 loop {
260 let zeros = word.trailing_zeros() as u64;
261 if zeros < bits_in_word {
263 self.bit_index += total + zeros + 1;
264 return Ok(total + zeros);
265 }
266 total += bits_in_word;
267 bits_in_word = 64;
268 word = self.data.read_word()?.to_le();
269 }
270 }
271
272 #[inline(always)]
273 fn skip_bits_after_peek(&mut self, n: usize) {
274 self.bit_index += n as u64;
275 }
276}
277
278impl<
279 E: Error + Send + Sync + 'static,
280 WR: WordRead<Error = E, Word = u64> + WordSeek<Error = E>,
281 RP: ReadParams,
282 > std::io::Read for BitReader<LE, WR, RP>
283{
284 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
285 let mut iter = buf.chunks_exact_mut(8);
286
287 for chunk in &mut iter {
288 let word = self
289 .read_bits(64)
290 .map_err(|_| std::io::ErrorKind::UnexpectedEof)?;
291 chunk.copy_from_slice(&word.to_le_bytes());
292 }
293
294 let rem = iter.into_remainder();
295 if !rem.is_empty() {
296 let word = self
297 .read_bits(rem.len() * 8)
298 .map_err(|_| std::io::ErrorKind::UnexpectedEof)?;
299 rem.copy_from_slice(&word.to_le_bytes()[..rem.len()]);
300 }
301
302 Ok(buf.len())
303 }
304}
305
306impl<
307 E: Error + Send + Sync + 'static,
308 WR: WordRead<Error = E, Word = u64> + WordSeek<Error = E>,
309 RP: ReadParams,
310 > std::io::Read for BitReader<BE, WR, RP>
311{
312 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
313 let mut iter = buf.chunks_exact_mut(8);
314
315 for chunk in &mut iter {
316 let word = self
317 .read_bits(64)
318 .map_err(|_| std::io::ErrorKind::UnexpectedEof)?;
319 chunk.copy_from_slice(&word.to_be_bytes());
320 }
321
322 let rem = iter.into_remainder();
323 if !rem.is_empty() {
324 let word = self
325 .read_bits(rem.len() * 8)
326 .map_err(|_| std::io::ErrorKind::UnexpectedEof)?;
327 rem.copy_from_slice(&word.to_be_bytes()[8 - rem.len()..]);
328 }
329
330 Ok(buf.len())
331 }
332}