1use crate::decode;
13use crate::{Cmr, FailEntropy};
14use std::{error, fmt};
15
/// Attempted to read from a bit iterator, but there was no more data.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct EarlyEndOfStreamError;

// Like `CloseError`, this is a public error type, so it gets `Display` and
// `Error` impls; without them it cannot be boxed as `dyn Error` or printed.
impl fmt::Display for EarlyEndOfStreamError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("bitstream ended early")
    }
}

impl error::Error for EarlyEndOfStreamError {
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        None
    }
}
19
/// Why [`BitIter::close`] refused to finish a stream.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum CloseError {
    /// The underlying byte iterator still had at least one byte left.
    TrailingBytes {
        /// The first byte that was left unread.
        first_byte: u8,
    },
    /// The unread bits of the final cached byte were not all zero.
    IllegalPadding {
        /// The final byte with its already-read bits masked off.
        masked_padding: u8,
        /// How many bits of the final byte were left unread.
        n_bits: usize,
    },
}
34
35impl fmt::Display for CloseError {
36 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
37 match self {
38 CloseError::TrailingBytes { first_byte } => {
39 write!(f, "bitstream had trailing bytes 0x{:02x}...", first_byte)
40 }
41 CloseError::IllegalPadding {
42 masked_padding,
43 n_bits,
44 } => write!(
45 f,
46 "bitstream had {n_bits} bits in its last byte 0x{:02x}, not all zero",
47 masked_padding
48 ),
49 }
50 }
51}
52
53impl error::Error for CloseError {
54 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
55 None
56 }
57}
58
/// A two-bit unsigned integer, as produced by `BitIter::read_u2`.
///
/// The variants are declared in ascending numeric order, so the derived `Ord`
/// compares them numerically.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
// The lowercase name deliberately mirrors the primitive integer types.
#[allow(non_camel_case_types)]
pub enum u2 {
    /// The value 0 (bits `00`).
    _0,
    /// The value 1 (bits `01`).
    _1,
    /// The value 2 (bits `10`).
    _2,
    /// The value 3 (bits `11`).
    _3,
}
71
72impl From<u2> for u8 {
73 fn from(s: u2) -> u8 {
74 match s {
75 u2::_0 => 0,
76 u2::_1 => 1,
77 u2::_2 => 2,
78 u2::_3 => 3,
79 }
80 }
81}
82
/// Adapts an iterator over bytes into an iterator over bits, yielding the
/// most significant bit of each byte first.
#[derive(Debug)]
pub struct BitIter<I: Iterator<Item = u8>> {
    /// Where fresh bytes come from.
    iter: I,
    /// The byte bits are currently being taken from.
    cached_byte: u8,
    /// How many bits of `cached_byte` have already been handed out.
    /// A value of 8 means the cache is spent and the next read must fetch a
    /// new byte from `iter`.
    read_bits: usize,
    /// Running count of bits read since this iterator was created.
    total_read: usize,
}
96
97impl From<Vec<u8>> for BitIter<std::vec::IntoIter<u8>> {
98 fn from(v: Vec<u8>) -> Self {
99 BitIter {
100 iter: v.into_iter(),
101 cached_byte: 0,
102 read_bits: 8,
105 total_read: 0,
106 }
107 }
108}
109
110impl<'a> From<&'a [u8]> for BitIter<std::iter::Copied<std::slice::Iter<'a, u8>>> {
111 fn from(sl: &'a [u8]) -> Self {
112 BitIter {
113 iter: sl.iter().copied(),
114 cached_byte: 0,
115 read_bits: 8,
118 total_read: 0,
119 }
120 }
121}
122
123impl<I: Iterator<Item = u8>> From<I> for BitIter<I> {
124 fn from(iter: I) -> Self {
125 BitIter {
126 iter,
127 cached_byte: 0,
128 read_bits: 8,
131 total_read: 0,
132 }
133 }
134}
135
136impl<I: Iterator<Item = u8>> Iterator for BitIter<I> {
137 type Item = bool;
138
139 fn next(&mut self) -> Option<bool> {
140 if self.read_bits < 8 {
141 self.read_bits += 1;
142 self.total_read += 1;
143 Some(self.cached_byte & (1 << (8 - self.read_bits as u8)) != 0)
144 } else {
145 self.cached_byte = self.iter.next()?;
146 self.read_bits = 0;
147 self.next()
148 }
149 }
150}
151
152impl<'a> BitIter<std::iter::Copied<std::slice::Iter<'a, u8>>> {
153 pub fn byte_slice_window(sl: &'a [u8], start: usize, end: usize) -> Self {
156 assert!(start <= end);
157 assert!(end <= sl.len() * 8);
158
159 let actual_sl = &sl[start / 8..end.div_ceil(8)];
160 let mut iter = actual_sl.iter().copied();
161
162 let read_bits = start % 8;
163 if read_bits == 0 {
164 BitIter {
165 iter,
166 cached_byte: 0,
167 read_bits: 8,
168 total_read: 0,
169 }
170 } else {
171 BitIter {
172 cached_byte: iter.by_ref().next().unwrap(),
173 iter,
174 read_bits,
175 total_read: 0,
176 }
177 }
178 }
179}
180
181impl<I: Iterator<Item = u8>> BitIter<I> {
182 pub fn new(iter: I) -> Self {
185 Self::from(iter)
186 }
187
188 pub fn read_bit(&mut self) -> Result<bool, EarlyEndOfStreamError> {
190 self.next().ok_or(EarlyEndOfStreamError)
191 }
192
193 pub fn read_u2(&mut self) -> Result<u2, EarlyEndOfStreamError> {
195 match (self.next(), self.next()) {
196 (Some(false), Some(false)) => Ok(u2::_0),
197 (Some(false), Some(true)) => Ok(u2::_1),
198 (Some(true), Some(false)) => Ok(u2::_2),
199 (Some(true), Some(true)) => Ok(u2::_3),
200 _ => Err(EarlyEndOfStreamError),
201 }
202 }
203
204 pub fn read_u8(&mut self) -> Result<u8, EarlyEndOfStreamError> {
206 debug_assert!(self.read_bits > 0);
207 let cached = self.cached_byte;
208 self.cached_byte = self.iter.next().ok_or(EarlyEndOfStreamError)?;
209 self.total_read += 8;
210
211 Ok(cached.checked_shl(self.read_bits as u32).unwrap_or(0)
212 + (self.cached_byte >> (8 - self.read_bits)))
213 }
214
215 pub fn read_cmr(&mut self) -> Result<Cmr, EarlyEndOfStreamError> {
217 let mut ret = [0; 32];
218 for byte in &mut ret {
219 *byte = self.read_u8()?;
220 }
221 Ok(Cmr::from_byte_array(ret))
222 }
223
224 pub fn read_fail_entropy(&mut self) -> Result<FailEntropy, EarlyEndOfStreamError> {
226 let mut ret = [0; 64];
227 for byte in &mut ret {
228 *byte = self.read_u8()?;
229 }
230 Ok(FailEntropy::from_byte_array(ret))
231 }
232
233 pub fn read_natural(&mut self, bound: Option<usize>) -> Result<usize, decode::Error> {
238 decode::decode_natural(self, bound)
239 }
240
241 pub fn n_total_read(&self) -> usize {
244 self.total_read
245 }
246
247 pub fn close(mut self) -> Result<(), CloseError> {
250 if let Some(first_byte) = self.iter.next() {
251 return Err(CloseError::TrailingBytes { first_byte });
252 }
253
254 debug_assert!(self.read_bits >= 1);
255 debug_assert!(self.read_bits <= 8);
256 let n_bits = 8 - self.read_bits;
257 let masked_padding = self.cached_byte & ((1u8 << n_bits) - 1);
258 if masked_padding != 0 {
259 Err(CloseError::IllegalPadding {
260 masked_padding,
261 n_bits,
262 })
263 } else {
264 Ok(())
265 }
266 }
267}
268
/// Packs a stream of bits into bytes, most significant bit first.
pub trait BitCollector: Sized {
    /// Consumes `self` and returns the packed bytes together with the number
    /// of bits collected.
    fn collect_bits(self) -> (Vec<u8>, usize);

    /// Like [`BitCollector::collect_bits`], but only succeeds when the number
    /// of collected bits is a whole number of bytes.
    ///
    /// # Errors
    ///
    /// Returns a static message if the bit count is not a multiple of 8.
    fn try_collect_bytes(self) -> Result<Vec<u8>, &'static str> {
        let (bytes, bit_length) = self.collect_bits();
        match bit_length % 8 {
            0 => Ok(bytes),
            _ => Err("Number of collected bits is not divisible by 8"),
        }
    }
}
286
287impl<I: Iterator<Item = bool>> BitCollector for I {
288 fn collect_bits(self) -> (Vec<u8>, usize) {
289 let mut bytes = vec![];
290 let mut unfinished_byte = Vec::with_capacity(8);
291
292 for bit in self {
293 unfinished_byte.push(bit);
294
295 if unfinished_byte.len() == 8 {
296 bytes.push(
297 unfinished_byte
298 .iter()
299 .fold(0, |acc, &b| acc * 2 + u8::from(b)),
300 );
301 unfinished_byte.clear();
302 }
303 }
304
305 let bit_length = bytes.len() * 8 + unfinished_byte.len();
306
307 if !unfinished_byte.is_empty() {
308 unfinished_byte.resize(8, false);
309 bytes.push(
310 unfinished_byte
311 .iter()
312 .fold(0, |acc, &b| acc * 2 + u8::from(b)),
313 );
314 }
315
316 (bytes, bit_length)
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_iter() {
        let mut iter = BitIter::from([].iter().cloned());
        assert!(iter.next().is_none());
        assert_eq!(iter.read_bit(), Err(EarlyEndOfStreamError));
        assert_eq!(iter.read_u2(), Err(EarlyEndOfStreamError));
        assert_eq!(iter.read_u8(), Err(EarlyEndOfStreamError));
        assert_eq!(iter.read_cmr(), Err(EarlyEndOfStreamError));
        assert_eq!(iter.n_total_read(), 0);
    }

    #[test]
    fn one_bit_iter() {
        let mut iter = BitIter::from([0x80].iter().cloned());
        assert_eq!(iter.read_bit(), Ok(true));
        assert_eq!(iter.read_bit(), Ok(false));
        assert_eq!(iter.read_u8(), Err(EarlyEndOfStreamError));
        assert_eq!(iter.n_total_read(), 2);
    }

    #[test]
    fn bit_by_bit() {
        let mut iter = BitIter::from([0x0f, 0xaa].iter().cloned());
        for _ in 0..4 {
            assert_eq!(iter.next(), Some(false));
        }
        for _ in 0..4 {
            assert_eq!(iter.next(), Some(true));
        }
        for _ in 0..4 {
            assert_eq!(iter.next(), Some(true));
            assert_eq!(iter.next(), Some(false));
        }
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn byte_by_byte() {
        let mut iter = BitIter::from([0x0f, 0xaa].iter().cloned());
        assert_eq!(iter.read_u8(), Ok(0x0f));
        assert_eq!(iter.read_u8(), Ok(0xaa));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn regression_1() {
        // 0x34 0x90 = 0011 0100 1001 0000
        let mut iter = BitIter::from([0x34, 0x90].iter().cloned());
        assert_eq!(iter.read_u2(), Ok(u2::_0)); // 00
        assert_eq!(iter.read_u2(), Ok(u2::_3)); // 11
        assert_eq!(iter.next(), Some(false)); // 0
        assert_eq!(iter.read_u2(), Ok(u2::_2)); // 10
        assert_eq!(iter.read_u2(), Ok(u2::_1)); // 01, straddling the byte boundary
        assert_eq!(iter.n_total_read(), 9);
    }

    #[test]
    fn byte_slice_window() {
        let data = [0x12, 0x23, 0x34];

        let mut full = BitIter::byte_slice_window(&data, 0, 24);
        assert_eq!(full.read_u8(), Ok(0x12));
        assert_eq!(full.n_total_read(), 8);
        assert_eq!(full.read_u8(), Ok(0x23));
        assert_eq!(full.n_total_read(), 16);
        assert_eq!(full.read_u8(), Ok(0x34));
        assert_eq!(full.n_total_read(), 24);
        assert_eq!(full.read_u8(), Err(EarlyEndOfStreamError));

        let mut mid = BitIter::byte_slice_window(&data, 8, 16);
        assert_eq!(mid.read_u8(), Ok(0x23));
        assert_eq!(mid.read_u8(), Err(EarlyEndOfStreamError));

        let mut offs = BitIter::byte_slice_window(&data, 4, 20);
        assert_eq!(offs.read_u8(), Ok(0x22));
        assert_eq!(offs.read_u8(), Ok(0x33));
        assert_eq!(offs.read_u8(), Err(EarlyEndOfStreamError));

        let mut shift1 = BitIter::byte_slice_window(&data, 1, 24);
        assert_eq!(shift1.read_u8(), Ok(0x24));
        assert_eq!(shift1.read_u8(), Ok(0x46));
        assert_eq!(shift1.read_u8(), Err(EarlyEndOfStreamError));

        let mut shift7 = BitIter::byte_slice_window(&data, 7, 24);
        assert_eq!(shift7.read_u8(), Ok(0x11));
        assert_eq!(shift7.read_u8(), Ok(0x9a));
        assert_eq!(shift7.read_u8(), Err(EarlyEndOfStreamError));
    }

    #[test]
    fn close_accepts_fully_read_stream() {
        // Byte-aligned and fully consumed.
        let mut iter = BitIter::from([0xab].iter().cloned());
        assert_eq!(iter.read_u8(), Ok(0xab));
        assert_eq!(iter.close(), Ok(()));

        // 0xa0 = 1010 0000: after reading 101, the unread bits are all zero.
        let mut iter = BitIter::from([0xa0].iter().cloned());
        assert_eq!(iter.read_bit(), Ok(true));
        assert_eq!(iter.read_bit(), Ok(false));
        assert_eq!(iter.read_bit(), Ok(true));
        assert_eq!(iter.close(), Ok(()));
    }

    #[test]
    fn close_rejects_trailing_bytes() {
        let mut iter = BitIter::from([0x00, 0x5a].iter().cloned());
        assert_eq!(iter.read_u8(), Ok(0x00));
        assert_eq!(
            iter.close(),
            Err(CloseError::TrailingBytes { first_byte: 0x5a })
        );
    }

    #[test]
    fn close_rejects_nonzero_padding() {
        // 0xa1 = 1010 0001: after one bit, the 7 unread bits are 010 0001.
        let mut iter = BitIter::from([0xa1].iter().cloned());
        assert_eq!(iter.read_bit(), Ok(true));
        assert_eq!(
            iter.close(),
            Err(CloseError::IllegalPadding {
                masked_padding: 0x21,
                n_bits: 7,
            })
        );
    }

    #[test]
    fn close_error_display() {
        assert_eq!(
            CloseError::TrailingBytes { first_byte: 0x5a }.to_string(),
            "bitstream had trailing bytes 0x5a..."
        );
        assert_eq!(
            CloseError::IllegalPadding {
                masked_padding: 0x21,
                n_bits: 7,
            }
            .to_string(),
            "bitstream had 7 bits in its last byte 0x21, not all zero"
        );
    }

    #[test]
    fn collect_bits_roundtrip() {
        let (bytes, len) = BitIter::from([0x0f, 0xaa].iter().cloned()).collect_bits();
        assert_eq!(bytes, vec![0x0f, 0xaa]);
        assert_eq!(len, 16);

        // A partial final byte is left-aligned and zero-padded.
        let (bytes, len) = BitIter::from([0xff].iter().cloned())
            .take(3)
            .collect_bits();
        assert_eq!(bytes, vec![0xe0]);
        assert_eq!(len, 3);

        assert_eq!(
            BitIter::from([0x0f, 0xaa].iter().cloned()).try_collect_bytes(),
            Ok(vec![0x0f, 0xaa])
        );
        assert_eq!(
            BitIter::from([0xff].iter().cloned())
                .take(3)
                .try_collect_bytes(),
            Err("Number of collected bits is not divisible by 8")
        );
    }
}