pco/
bit_reader.rs

1use std::io;
2
3use better_io::BetterBufRead;
4
5use crate::bits;
6use crate::constants::Bitlen;
7use crate::errors::{PcoError, PcoResult};
8use crate::read_write_uint::ReadWriteUint;
9
// Loads the 8 bytes of `src` starting at `byte_idx` as a little-endian u64.
//
// Safety: the load is not bounds checked, so the caller must guarantee that
// `byte_idx + 8 <= src.len()`.
//
// Q: Why u64?
// A: It's the largest data type most instruction sets have support for (and
//    can do few-cycle/SIMD ops on). e.g. even 32-bit wasm has 64-bit ints and
//    opcodes.
#[inline]
pub unsafe fn u64_at(src: &[u8], byte_idx: usize) -> u64 {
  // [u8; 8] has an alignment of 1, so this read is valid at any byte offset.
  let word_ptr = src.as_ptr().add(byte_idx).cast::<[u8; 8]>();
  u64::from_le_bytes(word_ptr.read())
}
19
// Loads the 4 bytes of `src` starting at `byte_idx` as a little-endian u32.
//
// Safety: the load is not bounds checked, so the caller must guarantee that
// `byte_idx + 4 <= src.len()`.
//
// Q: Why is there also a u32 version?
// A: This allows for better use of SIMD bandwidth when reading smaller latent
//    types compared to the u64 version.
#[inline]
pub unsafe fn u32_at(src: &[u8], byte_idx: usize) -> u32 {
  // [u8; 4] has an alignment of 1, so this read is valid at any byte offset.
  let word_ptr = src.as_ptr().add(byte_idx).cast::<[u8; 4]>();
  u32::from_le_bytes(word_ptr.read())
}
28
29#[inline]
30pub unsafe fn read_uint_at<U: ReadWriteUint, const READ_BYTES: usize>(
31  src: &[u8],
32  byte_idx: usize,
33  bits_past_byte: Bitlen,
34  n: Bitlen,
35) -> U {
36  // Q: Why is this fast?
37  // A: The compiler removes branching allowing for fast SIMD.
38  //
39  // Q: Why does this work?
40  // A: We set READ_BYTES so that,
41  //    0  to 25  bit reads -> 4 bytes (1 u32)
42  //    26 to 57  bit reads -> 8 bytes (1 u64)
43  //    58 to 113 bit reads -> 15 bytes (almost 2 u64's)
44  //    For the 1st u64, we read all bytes from the current u64. Due to our bit
45  //    packing, up to the first 7 of these may be useless, so we can read up
46  //    to (64 - 7) = 57 bits safely from a single u64. We right shift by only
47  //    up to 7 bits, which is safe.
48  //
49  //    For the 2nd u64, we skip only 7 bytes forward. This will overlap with
50  //    the 1st u64 by 1 byte, which seems useless, but allows us to avoid one
51  //    nasty case: left shifting by U::BITS (a panic). This could happen e.g.
52  //    with 64-bit reads when we start out byte-aligned (bits_past_byte=0).
53  //
54  //    For the 3rd u64 and onward (currently not implemented), we skip 8 bytes
55  //    forward. Due to how we handled the 2nd u64, the most we'll ever need to
56  //    shift by is U::BITS - 8, which is safe.
57
58  match READ_BYTES {
59    4 => read_u32_at(src, byte_idx, bits_past_byte, n),
60    8 => read_u64_at(src, byte_idx, bits_past_byte, n),
61    15 => read_almost_u64x2_at(src, byte_idx, bits_past_byte, n),
62    _ => unreachable!("invalid read bytes: {}", READ_BYTES),
63  }
64}
65
66#[inline]
67unsafe fn read_u32_at<U: ReadWriteUint>(
68  src: &[u8],
69  byte_idx: usize,
70  bits_past_byte: Bitlen,
71  n: Bitlen,
72) -> U {
73  debug_assert!(n <= 25);
74  U::from_u32(bits::lowest_bits_fast(
75    u32_at(src, byte_idx) >> bits_past_byte,
76    n,
77  ))
78}
79
80#[inline]
81unsafe fn read_u64_at<U: ReadWriteUint>(
82  src: &[u8],
83  byte_idx: usize,
84  bits_past_byte: Bitlen,
85  n: Bitlen,
86) -> U {
87  debug_assert!(n <= 57);
88  U::from_u64(bits::lowest_bits_fast(
89    u64_at(src, byte_idx) >> bits_past_byte,
90    n,
91  ))
92}
93
94#[inline]
95unsafe fn read_almost_u64x2_at<U: ReadWriteUint>(
96  src: &[u8],
97  byte_idx: usize,
98  bits_past_byte: Bitlen,
99  n: Bitlen,
100) -> U {
101  debug_assert!(n <= 113);
102  let first_word = U::from_u64(u64_at(src, byte_idx) >> bits_past_byte);
103  let processed = 56 - bits_past_byte;
104  let second_word = U::from_u64(u64_at(src, byte_idx + 7)) << processed;
105  bits::lowest_bits(first_word | second_word, n)
106}
107
// A bit-level cursor over a byte slice. The current position, in bits from
// the start of `src`, is `stale_byte_idx * 8 + bits_past_byte`.
pub struct BitReader<'a> {
  // The bytes being read.
  pub src: &'a [u8],
  // 8 times the `unpadded_byte_size` given to `new`. The bounds checks report
  // any position beyond this many bits as out of bounds.
  unpadded_bit_size: usize,

  pub stale_byte_idx: usize,  // in current stream
  pub bits_past_byte: Bitlen, // in current stream
}
115
116impl<'a> BitReader<'a> {
117  pub fn new(src: &'a [u8], unpadded_byte_size: usize, bits_past_byte: Bitlen) -> Self {
118    Self {
119      src,
120      unpadded_bit_size: unpadded_byte_size * 8,
121      stale_byte_idx: 0,
122      bits_past_byte,
123    }
124  }
125
126  #[inline]
127  pub fn bit_idx(&self) -> usize {
128    self.stale_byte_idx * 8 + self.bits_past_byte as usize
129  }
130
131  fn byte_idx(&self) -> usize {
132    self.bit_idx() / 8
133  }
134
135  // Returns the reader's current byte index. Will return an error if the
136  // reader is at a misaligned position.
137  fn aligned_byte_idx(&self) -> PcoResult<usize> {
138    if self.bits_past_byte.is_multiple_of(8) {
139      Ok(self.byte_idx())
140    } else {
141      Err(PcoError::invalid_argument(format!(
142        "cannot get aligned byte index on misaligned bit reader (byte {} + {} bits)",
143        self.stale_byte_idx, self.bits_past_byte,
144      )))
145    }
146  }
147
148  #[inline]
149  fn refill(&mut self) {
150    self.stale_byte_idx += (self.bits_past_byte / 8) as usize;
151    self.bits_past_byte %= 8;
152  }
153
154  #[inline]
155  fn consume(&mut self, n: Bitlen) {
156    self.bits_past_byte += n;
157  }
158
159  pub fn read_aligned_bytes(&mut self, n: usize) -> PcoResult<&'a [u8]> {
160    let byte_idx = self.aligned_byte_idx()?;
161    let new_byte_idx = byte_idx + n;
162    self.stale_byte_idx = new_byte_idx;
163    self.bits_past_byte = 0;
164    Ok(&self.src[byte_idx..new_byte_idx])
165  }
166
167  pub unsafe fn read_uint<U: ReadWriteUint>(&mut self, n: Bitlen) -> U {
168    self.refill();
169    let res = match U::MAX_BYTES {
170      1..=4 => read_uint_at::<U, 4>(
171        self.src,
172        self.stale_byte_idx,
173        self.bits_past_byte,
174        n,
175      ),
176      5..=8 => read_uint_at::<U, 8>(
177        self.src,
178        self.stale_byte_idx,
179        self.bits_past_byte,
180        n,
181      ),
182      9..=15 => read_uint_at::<U, 15>(
183        self.src,
184        self.stale_byte_idx,
185        self.bits_past_byte,
186        n,
187      ),
188      _ => unreachable!(
189        "[BitReader] unsupported max bytes: {}",
190        U::MAX_BYTES
191      ),
192    };
193    self.consume(n);
194    res
195  }
196
197  pub unsafe fn read_usize(&mut self, n: Bitlen) -> usize {
198    self.read_uint(n)
199  }
200
201  pub unsafe fn read_bitlen(&mut self, n: Bitlen) -> Bitlen {
202    self.read_uint(n)
203  }
204
205  pub unsafe fn read_bool(&mut self) -> bool {
206    self.read_uint::<u32>(1) > 0
207  }
208
209  // checks in bounds and returns bit idx
210  #[inline]
211  fn bit_idx_safe(&self) -> PcoResult<usize> {
212    let bit_idx = self.bit_idx();
213    if bit_idx > self.unpadded_bit_size {
214      return Err(PcoError::insufficient_data(format!(
215        "[BitReader] out of bounds at bit {} / {}",
216        bit_idx, self.unpadded_bit_size
217      )));
218    }
219    Ok(bit_idx)
220  }
221
222  pub fn check_in_bounds(&self) -> PcoResult<()> {
223    self.bit_idx_safe()?;
224    Ok(())
225  }
226
227  // Seek to the end of the byte, asserting it's all 0.
228  // Used to terminate each section of the file, since they
229  // always start and end byte-aligned.
230  pub fn drain_empty_byte(&mut self, message: &str) -> PcoResult<()> {
231    self.check_in_bounds()?;
232    self.refill();
233    if self.bits_past_byte != 0 {
234      if (self.src[self.stale_byte_idx] >> self.bits_past_byte) > 0 {
235        return Err(PcoError::corruption(message));
236      }
237      self.consume(8 - self.bits_past_byte);
238    }
239    Ok(())
240  }
241}
242
// Hands out `BitReader`s over the buffered contents of `inner`, carrying the
// read position from one reader to the next.
pub struct BitReaderBuilder<R: BetterBufRead> {
  // Number of bytes requested from `inner` before each reader is built, and
  // the number of zero bytes appended after the data once EOF is reached.
  padding: usize,
  // The underlying byte source.
  inner: R,
  // Once EOF is reached: the bytes `inner` still had buffered, followed by
  // `padding` zero bytes. Empty before that.
  eof_buffer: Vec<u8>,
  // Set once `inner` has supplied fewer than `padding` bytes.
  reached_eof: bool,
  // Number of bytes at the front of `eof_buffer` already consumed by earlier
  // readers.
  bytes_into_eof_buffer: usize,
  // Bit offset into the first byte at which the next reader starts.
  bits_past_byte: Bitlen,
}
251
252impl<R: BetterBufRead> BitReaderBuilder<R> {
253  pub fn new(inner: R, padding: usize, bits_past_byte: Bitlen) -> Self {
254    Self {
255      padding,
256      inner,
257      eof_buffer: vec![],
258      reached_eof: false,
259      bytes_into_eof_buffer: 0,
260      bits_past_byte,
261    }
262  }
263
264  fn build<'a>(&'a mut self) -> io::Result<BitReader<'a>> {
265    // could make n_bytes configurably smaller if it matters for some reason
266    let n_bytes_to_read = self.padding;
267    if !self.reached_eof {
268      self.inner.fill_or_eof(n_bytes_to_read)?;
269      let inner_bytes = self.inner.buffer();
270
271      if inner_bytes.len() < n_bytes_to_read {
272        self.reached_eof = true;
273        self.eof_buffer = vec![0; inner_bytes.len() + self.padding];
274        self.eof_buffer[..inner_bytes.len()].copy_from_slice(inner_bytes);
275      }
276    }
277
278    let src = if self.reached_eof {
279      &self.eof_buffer[self.bytes_into_eof_buffer..]
280    } else {
281      self.inner.buffer()
282    };
283
284    // we've reached the end of file buffer
285    let unpadded_bytes = if self.reached_eof {
286      self.eof_buffer.len() - self.padding - self.bytes_into_eof_buffer
287    } else {
288      src.len()
289    };
290    let bits_past_byte = self.bits_past_byte;
291    Ok(BitReader::new(
292      src,
293      unpadded_bytes,
294      bits_past_byte,
295    ))
296  }
297
298  pub fn into_inner(self) -> R {
299    self.inner
300  }
301
302  fn update(&mut self, final_bit_idx: usize) {
303    let bytes_consumed = final_bit_idx / 8;
304    self.inner.consume(bytes_consumed);
305    if self.reached_eof {
306      self.bytes_into_eof_buffer += bytes_consumed;
307    }
308    self.bits_past_byte = final_bit_idx as Bitlen % 8;
309  }
310
311  pub fn with_reader<Y, F: FnOnce(&mut BitReader) -> PcoResult<Y>>(
312    &mut self,
313    f: F,
314  ) -> PcoResult<Y> {
315    let mut reader = self.build()?;
316    let res = f(&mut reader)?;
317    let final_bit_idx = reader.bit_idx_safe()?;
318    self.update(final_bit_idx);
319    Ok(res)
320  }
321}
322
323pub fn ensure_buf_read_capacity<R: BetterBufRead>(src: &mut R, required: usize) {
324  if let Some(current_capacity) = src.capacity() {
325    if current_capacity < required {
326      src.resize_capacity(required);
327    }
328  }
329}
330
#[cfg(test)]
mod tests {
  use crate::constants::OVERSHOOT_PADDING;
  use crate::errors::{ErrorKind, PcoResult};

  use super::*;

  // I find little endian confusing, hence all the comments.
  // All the bytes in comments are written backwards,
  // e.g. 00000001 = 2^7

  // Exercises bit reads, aligned byte reads and draining on a bare BitReader.
  #[test]
  fn test_bit_reader() -> PcoResult<()> {
    // 10010001 01100100 11111111 10000010, followed by zero bytes
    let mut src = vec![137, 38, 255, 65];
    // zero-fill to 20 bytes so the unchecked reads have room to overshoot
    src.resize(20, 0);
    // the first 5 bytes count as in bounds
    let mut reader = BitReader::new(&src, 5, 0);

    unsafe {
      // low 4 bits of 137
      assert_eq!(reader.read_bitlen(4), 9);
      // 4 bits into a byte, so an aligned read must fail
      assert!(reader.read_aligned_bytes(1).is_err());
      // high 4 bits of 137
      assert_eq!(reader.read_bitlen(4), 8);
      assert_eq!(reader.read_aligned_bytes(1)?, vec![38]);
      // all of 255 plus the low 7 bits of 65
      assert_eq!(reader.read_usize(15), 255 + 65 * 256);
      // the one remaining bit of 65 is 0, so draining succeeds
      reader.drain_empty_byte("should be empty")?;
      assert_eq!(reader.aligned_byte_idx()?, 4);
    }
    Ok(())
  }

  // Exercises position carry-over between readers and the out-of-bounds
  // check at the end of the data.
  #[test]
  fn test_bit_reader_builder() -> PcoResult<()> {
    let src = (0..7).collect::<Vec<_>>();
    // start 1 bit into the first byte
    let mut reader_builder = BitReaderBuilder::new(src.as_slice(), 4 + OVERSHOOT_PADDING, 1);
    reader_builder.with_reader(|reader| unsafe {
      assert_eq!(&reader.src[0..4], &vec![0, 1, 2, 3]);
      assert_eq!(reader.bit_idx(), 1);
      assert_eq!(reader.read_usize(16), 1 << 7); // not 1 << 8, because we started at bit_idx 1
      Ok(())
    })?;
    // the first reader ended at bit 17, so 2 whole bytes were consumed and
    // 1 bit carries over
    reader_builder.with_reader(|reader| unsafe {
      assert_eq!(&reader.src[0..4], &vec![2, 3, 4, 5]);
      assert_eq!(reader.bit_idx(), 1);
      assert_eq!(reader.read_usize(7), 1);
      assert_eq!(reader.bit_idx(), 8);
      assert_eq!(reader.read_aligned_bytes(3)?, &vec![3, 4, 5]);
      Ok(())
    })?;
    // only 1 byte of real data is left at this point
    let err = reader_builder
      .with_reader(|reader| unsafe {
        assert!(reader.src.len() >= 4); // because of padding
        reader.read_usize(9); // this overshoots the end of the data by 1 bit
        Ok(())
      })
      .unwrap_err();
    assert!(matches!(
      err.kind,
      ErrorKind::InsufficientData
    ));

    Ok(())
  }
}