1use std::cmp::min;
2use std::fmt::{Debug, Display};
3use std::ops::*;
4
5use crate::bit_words::BitWords;
6use crate::bits;
7use crate::constants::{BITS_TO_ENCODE_N_ENTRIES, BYTES_PER_WORD, WORD_SIZE};
8use crate::data_types::UnsignedLike;
9use crate::errors::{QCompressError, QCompressResult};
10
/// Unsigned integer types that `BitReader` can assemble from raw words.
///
/// Implementors expose their zero, maximum and bit width as constants, plus a conversion from a
/// machine word, and support the shift/mask operators the reader uses to splice bits together.
pub(crate) trait ReadableUint
where
    Self: Copy
        + Debug
        + Display
        + Add<Output = Self>
        + BitAnd<Output = Self>
        + BitOr<Output = Self>
        + BitOrAssign
        + Shl<usize, Output = Self>
        + Shr<usize, Output = Self>,
{
    /// The additive identity (all bits clear).
    const ZERO: Self;
    /// The largest representable value.
    const MAX: Self;
    /// Width of the type in bits.
    const BITS: usize;

    /// Converts a machine word into `Self`.
    fn from_word(word: usize) -> Self;
}
28
29impl ReadableUint for usize {
30 const ZERO: Self = 0;
31 const MAX: Self = 0;
32 const BITS: usize = WORD_SIZE;
33
34 fn from_word(word: usize) -> Self {
35 word
36 }
37}
38
39impl<U: UnsignedLike> ReadableUint for U {
40 const ZERO: Self = <Self as UnsignedLike>::ZERO;
41 const MAX: Self = <Self as UnsignedLike>::MAX;
42 const BITS: usize = <Self as UnsignedLike>::BITS;
43
44 fn from_word(word: usize) -> Self {
45 <Self as UnsignedLike>::from_word(word)
46 }
47}
48
/// Cursor that reads bits out of a borrowed slice of words.
///
/// The position is word index `i` plus bit offset `j` within that word. `word` caches
/// `words[i]`, or 0 once the position has been moved past the last word.
#[derive(Clone, Debug)]
pub struct BitReader<'a> {
    // Cached copy of the current word; refreshed whenever `i` changes.
    word: usize,
    // Backing storage; only borrowed immutably.
    words: &'a [usize],
    // Index of the current word.
    i: usize,
    // Bit offset within the current word; may equal `WORD_SIZE` once that word is used up.
    j: usize,
    // Number of valid data bits in `words`; bits at or beyond this index are not data.
    total_bits: usize,
}
71
72impl<'a> From<&'a BitWords> for BitReader<'a> {
73 fn from(bit_words: &'a BitWords) -> Self {
74 let word = bit_words.words.first().copied().unwrap_or_default();
75 BitReader {
76 word,
77 words: &bit_words.words,
78 i: 0,
79 j: 0,
80 total_bits: bit_words.total_bits,
81 }
82 }
83}
84
85impl<'a> BitReader<'a> {
86 pub fn aligned_byte_idx(&self) -> QCompressResult<usize> {
90 if self.j % 8 == 0 {
91 Ok(self.i * BYTES_PER_WORD + self.j / 8)
92 } else {
93 Err(QCompressError::invalid_argument(format!(
94 "cannot get aligned byte index on misaligned bit reader at word {} bit {}",
95 self.i, self.j,
96 )))
97 }
98 }
99
100 pub fn bit_idx(&self) -> usize {
101 WORD_SIZE * self.i + self.j
102 }
103
104 pub fn bits_remaining(&self) -> usize {
107 self.total_bits - self.bit_idx()
108 }
109
110 pub fn byte_size(&self) -> usize {
112 bits::ceil_div(self.total_bits, 8)
113 }
114
115 fn increment_i(&mut self) {
116 self.i += 1;
117 self.update_unsafe_word();
118 }
119
120 fn update_unsafe_word(&mut self) {
121 self.word = self.words[self.i];
122 }
123
124 #[inline]
125 fn refresh_if_needed(&mut self) {
126 if self.j == WORD_SIZE {
127 self.increment_i();
128 self.j = 0;
129 }
130 }
131
132 fn insufficient_data_check(&self, name: &str, n: usize) -> QCompressResult<()> {
133 let bit_idx = self.bit_idx();
134 if bit_idx + n > self.total_bits {
135 Err(QCompressError::insufficient_data_recipe(
136 name,
137 n,
138 bit_idx,
139 self.total_bits,
140 ))
141 } else {
142 Ok(())
143 }
144 }
145
146 pub fn read_aligned_bytes(&mut self, n: usize) -> QCompressResult<Vec<u8>> {
150 let byte_idx = self.aligned_byte_idx()?;
151 let new_byte_idx = byte_idx + n;
152 let byte_size = self.byte_size();
153 if new_byte_idx > byte_size {
154 Err(QCompressError::insufficient_data(format!(
155 "cannot read {} aligned bytes at byte idx {} out of {}",
156 n, byte_idx, byte_size,
157 )))
158 } else {
159 self.refresh_if_needed();
160
161 let end_word_idx = bits::ceil_div(new_byte_idx, BYTES_PER_WORD);
162 let padded_bytes = bits::words_to_bytes(&self.words[byte_idx / BYTES_PER_WORD..end_word_idx]);
163
164 self.seek(n * 8);
165 let padded_start_idx = byte_idx % BYTES_PER_WORD;
166 Ok(padded_bytes[padded_start_idx..padded_start_idx + n].to_vec())
167 }
168 }
169
170 pub fn read_one(&mut self) -> QCompressResult<bool> {
173 self.insufficient_data_check("read_one", 1)?;
174 Ok(self.unchecked_read_one())
175 }
176
177 pub fn read(&mut self, n: usize) -> QCompressResult<Vec<bool>> {
180 self.insufficient_data_check("read", n)?;
181
182 let mut res = Vec::with_capacity(n);
183
184 for _ in 0..n {
186 res.push(self.unchecked_read_one());
187 }
188 Ok(res)
189 }
190
191 pub(crate) fn read_uint<U: ReadableUint>(&mut self, n: usize) -> QCompressResult<U> {
192 self.insufficient_data_check("read_uint", n)?;
193
194 Ok(self.unchecked_read_uint::<U>(n))
195 }
196
197 pub fn read_usize(&mut self, n: usize) -> QCompressResult<usize> {
198 self.read_uint::<usize>(n)
199 }
200
201 pub fn read_prefix_table_idx(
203 &mut self,
204 table_size_log: usize,
205 ) -> QCompressResult<(usize, usize)> {
206 let bit_idx = self.bit_idx();
207 if bit_idx >= self.total_bits {
208 return Err(QCompressError::insufficient_data_recipe(
209 "read_prefix_table_idx",
210 1,
211 bit_idx,
212 self.total_bits,
213 ));
214 }
215
216 self.refresh_if_needed();
217
218 let n_plus_j = table_size_log + self.j;
219 if n_plus_j <= WORD_SIZE {
220 let rshift = WORD_SIZE - n_plus_j;
221 let res = (self.word & (usize::MAX >> self.j)) >> rshift;
222 let bits_read = min(table_size_log, self.total_bits - bit_idx);
223 self.j += bits_read;
224 Ok((bits_read, res))
225 } else {
226 let remaining = n_plus_j - WORD_SIZE;
227 let mut res = (self.word & (usize::MAX >> self.j)) << remaining;
228 if self.i + 1 < self.words.len() {
229 self.increment_i();
230 let shift = WORD_SIZE - remaining;
231 res |= self.word >> shift;
232 self.j = remaining;
233 Ok((table_size_log, res))
234 } else {
235 self.j = WORD_SIZE;
236 Ok((table_size_log - remaining, res))
237 }
238 }
239 }
240
241 pub fn read_varint(&mut self, jumpstart: usize) -> QCompressResult<usize> {
242 let mut res = self.read_usize(jumpstart)?;
243 for i in jumpstart..BITS_TO_ENCODE_N_ENTRIES {
244 if self.read_one()? {
245 if self.read_one()? {
246 res |= 1 << i
247 }
248 } else {
249 break;
250 }
251 }
252 Ok(res)
253 }
254
255 pub fn unchecked_read_one(&mut self) -> bool {
258 self.refresh_if_needed();
259
260 let res = bits::bit_from_word(self.word, self.j);
261 self.j += 1;
262 res
263 }
264
265 pub(crate) fn unchecked_read_uint<U: ReadableUint>(&mut self, n: usize) -> U {
266 if n == 0 {
267 return U::ZERO;
268 }
269
270 self.refresh_if_needed();
271
272 let n_plus_j = n + self.j;
273 let first_masked_word = self.word & (usize::MAX >> self.j);
274 if n_plus_j <= WORD_SIZE {
275 let shift = WORD_SIZE - n_plus_j;
277 self.j = n_plus_j;
278 U::from_word(first_masked_word >> shift)
279 } else {
280 let mut remaining = n_plus_j - WORD_SIZE;
281 let mut res = U::from_word(first_masked_word << remaining);
282 self.increment_i();
283 for _ in 0..(U::BITS - 1) / WORD_SIZE {
287 if remaining <= WORD_SIZE {
288 break;
289 }
290 remaining -= WORD_SIZE;
291 res |= U::from_word(self.word) << remaining;
292 self.increment_i();
293 }
294
295 self.j = remaining;
296 let shift = WORD_SIZE - remaining;
297 res | U::from_word(self.word >> shift)
298 }
299 }
300
301 #[inline]
302 pub fn unchecked_read_usize(&mut self, n: usize) -> usize {
303 self.unchecked_read_uint::<usize>(n)
304 }
305
306 pub fn unchecked_read_varint(&mut self, jumpstart: usize) -> usize {
307 let mut res = self.unchecked_read_usize(jumpstart);
308 for i in jumpstart..BITS_TO_ENCODE_N_ENTRIES {
309 if self.unchecked_read_one() {
310 if self.unchecked_read_one() {
311 res |= 1 << i
312 }
313 } else {
314 break;
315 }
316 }
317 res
318 }
319
320 pub fn drain_empty_byte(&mut self, message: &str) -> QCompressResult<()> {
324 if self.j % 8 != 0 {
325 let end_j = 8 * bits::ceil_div(self.j, 8);
326 if self.word & (usize::MAX >> self.j) & (usize::MAX << (WORD_SIZE - end_j)) > 0 {
327 return Err(QCompressError::corruption(message));
328 }
329 self.j = end_j;
330 }
331 Ok(())
332 }
333
334 pub fn seek_to(&mut self, bit_idx: usize) {
337 self.i = bit_idx.div_euclid(WORD_SIZE);
338 self.j = bit_idx.rem_euclid(WORD_SIZE);
339 self.word = self.words.get(self.i).copied().unwrap_or(0);
340 }
341
342 pub fn seek(&mut self, n: usize) {
347 self.seek_to(self.bit_idx() + n);
348 }
349
350 pub fn rewind_prefix_overshoot(&mut self, n: usize) {
353 if n <= self.j {
354 self.j -= n;
355 } else {
356 self.i -= 1;
357 self.j = self.j + WORD_SIZE - n;
358 self.update_unsafe_word();
359 }
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::BitReader;
    use crate::bit_words::BitWords;
    use crate::errors::QCompressResult;

    /// Walks a 3-byte buffer through each read flavour, checking every decoded value.
    #[test]
    fn test_bit_reader() -> QCompressResult<()> {
        // Bits: 1001_1010 0110_1011 0010_1101
        let bytes = vec![0x9a, 0x6b, 0x2d];
        let words = BitWords::from(&bytes);
        let mut reader = BitReader::from(&words);

        let aligned = reader.read_aligned_bytes(1)?;
        assert_eq!(aligned, vec![0x9a]);

        let zero_bit = reader.unchecked_read_one();
        assert!(!zero_bit);
        let one_bit = reader.read_one()?;
        assert!(one_bit);

        let three_bits = reader.read(3)?;
        assert_eq!(three_bits, vec![true, false, true]);

        let two_bit_uint = reader.unchecked_read_uint::<u64>(2);
        assert_eq!(two_bit_uint, 1_u64);
        let three_bit_uint = reader.unchecked_read_uint::<u32>(3);
        assert_eq!(three_bit_uint, 4_u32);

        let varint = reader.unchecked_read_varint(2);
        assert_eq!(varint, 6);
        Ok(())
    }

    /// Seeks forward, then rewinds in steps, checking the absolute position after each one.
    #[test]
    fn test_seek_rewind() {
        let bytes = vec![0; 6];
        let words = BitWords::from(&bytes);
        let mut reader = BitReader::from(&words);
        reader.seek(43);

        // (bits to rewind, expected absolute bit index afterwards)
        let steps = [(2, 41), (2, 39), (7, 32), (8, 24), (17, 7)];
        for &(rewind_by, expected_idx) in &steps {
            reader.rewind_prefix_overshoot(rewind_by);
            assert_eq!(reader.bit_idx(), expected_idx);
        }
    }
}