1use std::io;
2
3use better_io::BetterBufRead;
4
5use crate::bits;
6use crate::constants::Bitlen;
7use crate::errors::{PcoError, PcoResult};
8use crate::read_write_uint::ReadWriteUint;
9
/// Loads the 8 bytes of `src` beginning at `byte_idx` and decodes them as a
/// little-endian `u64`.
///
/// # Safety
///
/// No bounds check is performed. The caller must guarantee that
/// `byte_idx + 8 <= src.len()`.
#[inline]
pub unsafe fn u64_at(src: &[u8], byte_idx: usize) -> u64 {
    // `[u8; 8]` has an alignment of 1, so it can be read at any byte offset.
    let word_ptr = src.as_ptr().add(byte_idx).cast::<[u8; 8]>();
    u64::from_le_bytes(word_ptr.read())
}
19
/// Loads the 4 bytes of `src` beginning at `byte_idx` and decodes them as a
/// little-endian `u32`.
///
/// # Safety
///
/// No bounds check is performed. The caller must guarantee that
/// `byte_idx + 4 <= src.len()`.
#[inline]
pub unsafe fn u32_at(src: &[u8], byte_idx: usize) -> u32 {
    // `[u8; 4]` has an alignment of 1, so it can be read at any byte offset.
    let word_ptr = src.as_ptr().add(byte_idx).cast::<[u8; 4]>();
    u32::from_le_bytes(word_ptr.read())
}
28
29#[inline]
30pub unsafe fn read_uint_at<U: ReadWriteUint, const READ_BYTES: usize>(
31 src: &[u8],
32 byte_idx: usize,
33 bits_past_byte: Bitlen,
34 n: Bitlen,
35) -> U {
36 match READ_BYTES {
59 4 => read_u32_at(src, byte_idx, bits_past_byte, n),
60 8 => read_u64_at(src, byte_idx, bits_past_byte, n),
61 15 => read_almost_u64x2_at(src, byte_idx, bits_past_byte, n),
62 _ => unreachable!("invalid read bytes: {}", READ_BYTES),
63 }
64}
65
66#[inline]
67unsafe fn read_u32_at<U: ReadWriteUint>(
68 src: &[u8],
69 byte_idx: usize,
70 bits_past_byte: Bitlen,
71 n: Bitlen,
72) -> U {
73 debug_assert!(n <= 25);
74 U::from_u32(bits::lowest_bits_fast(
75 u32_at(src, byte_idx) >> bits_past_byte,
76 n,
77 ))
78}
79
80#[inline]
81unsafe fn read_u64_at<U: ReadWriteUint>(
82 src: &[u8],
83 byte_idx: usize,
84 bits_past_byte: Bitlen,
85 n: Bitlen,
86) -> U {
87 debug_assert!(n <= 57);
88 U::from_u64(bits::lowest_bits_fast(
89 u64_at(src, byte_idx) >> bits_past_byte,
90 n,
91 ))
92}
93
94#[inline]
95unsafe fn read_almost_u64x2_at<U: ReadWriteUint>(
96 src: &[u8],
97 byte_idx: usize,
98 bits_past_byte: Bitlen,
99 n: Bitlen,
100) -> U {
101 debug_assert!(n <= 113);
102 let first_word = U::from_u64(u64_at(src, byte_idx) >> bits_past_byte);
103 let processed = 56 - bits_past_byte;
104 let second_word = U::from_u64(u64_at(src, byte_idx + 7)) << processed;
105 bits::lowest_bits(first_word | second_word, n)
106}
107
108pub struct BitReader<'a> {
109 pub src: &'a [u8],
110 unpadded_bit_size: usize,
111
112 pub stale_byte_idx: usize, pub bits_past_byte: Bitlen, }
115
116impl<'a> BitReader<'a> {
117 pub fn new(src: &'a [u8], unpadded_byte_size: usize, bits_past_byte: Bitlen) -> Self {
118 Self {
119 src,
120 unpadded_bit_size: unpadded_byte_size * 8,
121 stale_byte_idx: 0,
122 bits_past_byte,
123 }
124 }
125
126 #[inline]
127 pub fn bit_idx(&self) -> usize {
128 self.stale_byte_idx * 8 + self.bits_past_byte as usize
129 }
130
131 fn byte_idx(&self) -> usize {
132 self.bit_idx() / 8
133 }
134
135 fn aligned_byte_idx(&self) -> PcoResult<usize> {
138 if self.bits_past_byte.is_multiple_of(8) {
139 Ok(self.byte_idx())
140 } else {
141 Err(PcoError::invalid_argument(format!(
142 "cannot get aligned byte index on misaligned bit reader (byte {} + {} bits)",
143 self.stale_byte_idx, self.bits_past_byte,
144 )))
145 }
146 }
147
148 #[inline]
149 fn refill(&mut self) {
150 self.stale_byte_idx += (self.bits_past_byte / 8) as usize;
151 self.bits_past_byte %= 8;
152 }
153
154 #[inline]
155 fn consume(&mut self, n: Bitlen) {
156 self.bits_past_byte += n;
157 }
158
159 pub fn read_aligned_bytes(&mut self, n: usize) -> PcoResult<&'a [u8]> {
160 let byte_idx = self.aligned_byte_idx()?;
161 let new_byte_idx = byte_idx + n;
162 self.stale_byte_idx = new_byte_idx;
163 self.bits_past_byte = 0;
164 Ok(&self.src[byte_idx..new_byte_idx])
165 }
166
167 pub unsafe fn read_uint<U: ReadWriteUint>(&mut self, n: Bitlen) -> U {
168 self.refill();
169 let res = match U::MAX_BYTES {
170 1..=4 => read_uint_at::<U, 4>(
171 self.src,
172 self.stale_byte_idx,
173 self.bits_past_byte,
174 n,
175 ),
176 5..=8 => read_uint_at::<U, 8>(
177 self.src,
178 self.stale_byte_idx,
179 self.bits_past_byte,
180 n,
181 ),
182 9..=15 => read_uint_at::<U, 15>(
183 self.src,
184 self.stale_byte_idx,
185 self.bits_past_byte,
186 n,
187 ),
188 _ => unreachable!(
189 "[BitReader] unsupported max bytes: {}",
190 U::MAX_BYTES
191 ),
192 };
193 self.consume(n);
194 res
195 }
196
197 pub unsafe fn read_usize(&mut self, n: Bitlen) -> usize {
198 self.read_uint(n)
199 }
200
201 pub unsafe fn read_bitlen(&mut self, n: Bitlen) -> Bitlen {
202 self.read_uint(n)
203 }
204
205 pub unsafe fn read_bool(&mut self) -> bool {
206 self.read_uint::<u32>(1) > 0
207 }
208
209 #[inline]
211 fn bit_idx_safe(&self) -> PcoResult<usize> {
212 let bit_idx = self.bit_idx();
213 if bit_idx > self.unpadded_bit_size {
214 return Err(PcoError::insufficient_data(format!(
215 "[BitReader] out of bounds at bit {} / {}",
216 bit_idx, self.unpadded_bit_size
217 )));
218 }
219 Ok(bit_idx)
220 }
221
222 pub fn check_in_bounds(&self) -> PcoResult<()> {
223 self.bit_idx_safe()?;
224 Ok(())
225 }
226
227 pub fn drain_empty_byte(&mut self, message: &str) -> PcoResult<()> {
231 self.check_in_bounds()?;
232 self.refill();
233 if self.bits_past_byte != 0 {
234 if (self.src[self.stale_byte_idx] >> self.bits_past_byte) > 0 {
235 return Err(PcoError::corruption(message));
236 }
237 self.consume(8 - self.bits_past_byte);
238 }
239 Ok(())
240 }
241}
242
243pub struct BitReaderBuilder<R: BetterBufRead> {
244 padding: usize,
245 inner: R,
246 eof_buffer: Vec<u8>,
247 reached_eof: bool,
248 bytes_into_eof_buffer: usize,
249 bits_past_byte: Bitlen,
250}
251
252impl<R: BetterBufRead> BitReaderBuilder<R> {
253 pub fn new(inner: R, padding: usize, bits_past_byte: Bitlen) -> Self {
254 Self {
255 padding,
256 inner,
257 eof_buffer: vec![],
258 reached_eof: false,
259 bytes_into_eof_buffer: 0,
260 bits_past_byte,
261 }
262 }
263
264 fn build<'a>(&'a mut self) -> io::Result<BitReader<'a>> {
265 let n_bytes_to_read = self.padding;
267 if !self.reached_eof {
268 self.inner.fill_or_eof(n_bytes_to_read)?;
269 let inner_bytes = self.inner.buffer();
270
271 if inner_bytes.len() < n_bytes_to_read {
272 self.reached_eof = true;
273 self.eof_buffer = vec![0; inner_bytes.len() + self.padding];
274 self.eof_buffer[..inner_bytes.len()].copy_from_slice(inner_bytes);
275 }
276 }
277
278 let src = if self.reached_eof {
279 &self.eof_buffer[self.bytes_into_eof_buffer..]
280 } else {
281 self.inner.buffer()
282 };
283
284 let unpadded_bytes = if self.reached_eof {
286 self.eof_buffer.len() - self.padding - self.bytes_into_eof_buffer
287 } else {
288 src.len()
289 };
290 let bits_past_byte = self.bits_past_byte;
291 Ok(BitReader::new(
292 src,
293 unpadded_bytes,
294 bits_past_byte,
295 ))
296 }
297
298 pub fn into_inner(self) -> R {
299 self.inner
300 }
301
302 fn update(&mut self, final_bit_idx: usize) {
303 let bytes_consumed = final_bit_idx / 8;
304 self.inner.consume(bytes_consumed);
305 if self.reached_eof {
306 self.bytes_into_eof_buffer += bytes_consumed;
307 }
308 self.bits_past_byte = final_bit_idx as Bitlen % 8;
309 }
310
311 pub fn with_reader<Y, F: FnOnce(&mut BitReader) -> PcoResult<Y>>(
312 &mut self,
313 f: F,
314 ) -> PcoResult<Y> {
315 let mut reader = self.build()?;
316 let res = f(&mut reader)?;
317 let final_bit_idx = reader.bit_idx_safe()?;
318 self.update(final_bit_idx);
319 Ok(res)
320 }
321}
322
323pub fn ensure_buf_read_capacity<R: BetterBufRead>(src: &mut R, required: usize) {
324 if let Some(current_capacity) = src.capacity() {
325 if current_capacity < required {
326 src.resize_capacity(required);
327 }
328 }
329}
330
#[cfg(test)]
mod tests {
    use crate::constants::OVERSHOOT_PADDING;
    use crate::errors::{ErrorKind, PcoResult};

    use super::*;

    /// Walks a `BitReader` through unaligned reads, aligned byte reads and a
    /// final drain over a small hand-written buffer.
    #[test]
    fn test_bit_reader() -> PcoResult<()> {
        let mut data = vec![137, 38, 255, 65];
        // Zero-pad so the unchecked multi-byte loads stay inside the buffer.
        data.resize(20, 0);
        let mut reader = BitReader::new(&data, 5, 0);

        unsafe {
            // 137 == 0b1000_1001: the low nibble (9) comes out first.
            assert_eq!(reader.read_bitlen(4), 9);
            // Halfway through a byte, so an aligned read must be refused.
            assert!(reader.read_aligned_bytes(1).is_err());
            assert_eq!(reader.read_bitlen(4), 8);
            assert_eq!(reader.read_aligned_bytes(1)?, vec![38]);
            // 15 bits: all of 255 plus the low 7 bits of 65.
            assert_eq!(reader.read_usize(15), 255 + 65 * 256);
            // The one remaining bit of byte 3 is zero, so draining succeeds.
            reader.drain_empty_byte("should be empty")?;
            assert_eq!(reader.aligned_byte_idx()?, 4);
        }
        Ok(())
    }

    /// Checks that a `BitReaderBuilder` carries its position across readers and
    /// reports insufficient data when a reader runs past the end of the input.
    #[test]
    fn test_bit_reader_builder() -> PcoResult<()> {
        let data = (0..7).collect::<Vec<_>>();
        let mut builder = BitReaderBuilder::new(data.as_slice(), 4 + OVERSHOOT_PADDING, 1);

        builder.with_reader(|reader| unsafe {
            assert_eq!(&reader.src[0..4], &vec![0, 1, 2, 3]);
            assert_eq!(reader.bit_idx(), 1);
            // Of bits 1..17, only bit 8 (the low bit of byte 1) is set; it
            // lands at position 7 of the result.
            assert_eq!(reader.read_usize(16), 1 << 7);
            Ok(())
        })?;

        // The previous reader ended at bit 17: two whole bytes consumed, one
        // bit carried over.
        builder.with_reader(|reader| unsafe {
            assert_eq!(&reader.src[0..4], &vec![2, 3, 4, 5]);
            assert_eq!(reader.bit_idx(), 1);
            assert_eq!(reader.read_usize(7), 1);
            assert_eq!(reader.bit_idx(), 8);
            assert_eq!(reader.read_aligned_bytes(3)?, &vec![3, 4, 5]);
            Ok(())
        })?;

        // Six of the seven input bytes are now consumed, so ending 9 bits into
        // the last one is out of bounds.
        let err = builder
            .with_reader(|reader| unsafe {
                assert!(reader.src.len() >= 4);
                reader.read_usize(9);
                Ok(())
            })
            .unwrap_err();
        assert!(matches!(err.kind, ErrorKind::InsufficientData));

        Ok(())
    }
}