simplicity/bit_encoding/
bititer.rs

1// SPDX-License-Identifier: CC0-1.0
2
3//! Bit Iterator functionality
4//!
5//! Simplicity programs are encoded bitwise rather than bytewise. This
6//! module provides some helper functionality to make efficient parsing
7//! easier. In particular, the `BitIter` type takes a byte iterator and
//! wraps it with some additional functionality (including implementing
//! `Iterator<Item=bool>`).
10//!
11
12use crate::decode;
13use crate::{Cmr, FailEntropy};
14use std::{error, fmt};
15
/// Attempted to read from a bit iterator, but there was no more data.
///
/// Returned by `BitIter::read_bit`, `BitIter::read_u2`, `BitIter::read_u8`
/// and the readers built on top of them when the wrapped byte iterator runs
/// out before the requested bits could be produced.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct EarlyEndOfStreamError;
19
/// Closed out a bit iterator and there was remaining data.
///
/// Returned by `BitIter::close`.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum CloseError {
    /// The iterator was closed but the underlying byte iterator was
    /// still yielding data.
    TrailingBytes {
        /// The first unused byte from the iterator.
        first_byte: u8,
    },
    /// The iterator was closed partway through its last byte and the
    /// unread (low-order) bits of that byte were not all zero.
    IllegalPadding {
        /// The last byte with the already-read bits masked off, i.e. only
        /// the unread padding bits.
        masked_padding: u8,
        /// Number of unread bits in the last byte.
        n_bits: usize,
    },
}
34
35impl fmt::Display for CloseError {
36    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
37        match self {
38            CloseError::TrailingBytes { first_byte } => {
39                write!(f, "bitstream had trailing bytes 0x{:02x}...", first_byte)
40            }
41            CloseError::IllegalPadding {
42                masked_padding,
43                n_bits,
44            } => write!(
45                f,
46                "bitstream had {n_bits} bits in its last byte 0x{:02x}, not all zero",
47                masked_padding
48            ),
49        }
50    }
51}
52
53impl error::Error for CloseError {
54    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
55        None
56    }
57}
58
/// Two-bit type used during decoding
///
/// Use of an enum rather than a numeric primitive type makes it easier to
/// match on.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
// The type name is lowercase, like the primitive integer types, which the
// camel-case lint would otherwise reject.
#[allow(non_camel_case_types)]
pub enum u2 {
    /// The value 0 (bits `00`).
    _0,
    /// The value 1 (bits `01`).
    _1,
    /// The value 2 (bits `10`).
    _2,
    /// The value 3 (bits `11`).
    _3,
}
71
72impl From<u2> for u8 {
73    fn from(s: u2) -> u8 {
74        match s {
75            u2::_0 => 0,
76            u2::_1 => 1,
77            u2::_2 => 2,
78            u2::_3 => 3,
79        }
80    }
81}
82
/// Bitwise iterator formed from a wrapped bytewise iterator. Bytes are
/// interpreted big-endian, i.e. MSB is returned first
#[derive(Debug)]
pub struct BitIter<I: Iterator<Item = u8>> {
    /// Byte iterator
    iter: I,
    /// Current byte that contains next bit
    ///
    /// The `From` conversions initialise this to `0` and pair it with
    /// `read_bits == 8`, so that placeholder is never handed out as bits.
    cached_byte: u8,
    /// Number of read bits in current byte
    ///
    /// A value of 8 means the cached byte is used up, so the next bit read
    /// first fetches a fresh byte from `iter`.
    read_bits: usize,
    /// Total number of read bits
    total_read: usize,
}
96
97impl From<Vec<u8>> for BitIter<std::vec::IntoIter<u8>> {
98    fn from(v: Vec<u8>) -> Self {
99        BitIter {
100            iter: v.into_iter(),
101            cached_byte: 0,
102            // Set to 8 to force next `Iterator::next` to read a new byte
103            // from the underlying iterator
104            read_bits: 8,
105            total_read: 0,
106        }
107    }
108}
109
110impl<'a> From<&'a [u8]> for BitIter<std::iter::Copied<std::slice::Iter<'a, u8>>> {
111    fn from(sl: &'a [u8]) -> Self {
112        BitIter {
113            iter: sl.iter().copied(),
114            cached_byte: 0,
115            // Set to 8 to force next `Iterator::next` to read a new byte
116            // from the underlying iterator
117            read_bits: 8,
118            total_read: 0,
119        }
120    }
121}
122
123impl<I: Iterator<Item = u8>> From<I> for BitIter<I> {
124    fn from(iter: I) -> Self {
125        BitIter {
126            iter,
127            cached_byte: 0,
128            // Set to 8 to force next `Iterator::next` to read a new byte
129            // from the underlying iterator
130            read_bits: 8,
131            total_read: 0,
132        }
133    }
134}
135
136impl<I: Iterator<Item = u8>> Iterator for BitIter<I> {
137    type Item = bool;
138
139    fn next(&mut self) -> Option<bool> {
140        if self.read_bits < 8 {
141            self.read_bits += 1;
142            self.total_read += 1;
143            Some(self.cached_byte & (1 << (8 - self.read_bits as u8)) != 0)
144        } else {
145            self.cached_byte = self.iter.next()?;
146            self.read_bits = 0;
147            self.next()
148        }
149    }
150}
151
152impl<'a> BitIter<std::iter::Copied<std::slice::Iter<'a, u8>>> {
153    /// Creates a new bitwise iterator from a bytewise one. Equivalent
154    /// to using `From`
155    pub fn byte_slice_window(sl: &'a [u8], start: usize, end: usize) -> Self {
156        assert!(start <= end);
157        assert!(end <= sl.len() * 8);
158
159        let actual_sl = &sl[start / 8..end.div_ceil(8)];
160        let mut iter = actual_sl.iter().copied();
161
162        let read_bits = start % 8;
163        if read_bits == 0 {
164            BitIter {
165                iter,
166                cached_byte: 0,
167                read_bits: 8,
168                total_read: 0,
169            }
170        } else {
171            BitIter {
172                cached_byte: iter.by_ref().next().unwrap(),
173                iter,
174                read_bits,
175                total_read: 0,
176            }
177        }
178    }
179}
180
181impl<I: Iterator<Item = u8>> BitIter<I> {
182    /// Creates a new bitwise iterator from a bytewise one. Equivalent
183    /// to using `From`
184    pub fn new(iter: I) -> Self {
185        Self::from(iter)
186    }
187
188    /// Reads a single bit from the iterator.
189    pub fn read_bit(&mut self) -> Result<bool, EarlyEndOfStreamError> {
190        self.next().ok_or(EarlyEndOfStreamError)
191    }
192
193    /// Reads two bits from the iterator.
194    pub fn read_u2(&mut self) -> Result<u2, EarlyEndOfStreamError> {
195        match (self.next(), self.next()) {
196            (Some(false), Some(false)) => Ok(u2::_0),
197            (Some(false), Some(true)) => Ok(u2::_1),
198            (Some(true), Some(false)) => Ok(u2::_2),
199            (Some(true), Some(true)) => Ok(u2::_3),
200            _ => Err(EarlyEndOfStreamError),
201        }
202    }
203
204    /// Reads a byte from the iterator.
205    pub fn read_u8(&mut self) -> Result<u8, EarlyEndOfStreamError> {
206        debug_assert!(self.read_bits > 0);
207        let cached = self.cached_byte;
208        self.cached_byte = self.iter.next().ok_or(EarlyEndOfStreamError)?;
209        self.total_read += 8;
210
211        Ok(cached.checked_shl(self.read_bits as u32).unwrap_or(0)
212            + (self.cached_byte >> (8 - self.read_bits)))
213    }
214
215    /// Reads a 256-bit CMR from the iterator.
216    pub fn read_cmr(&mut self) -> Result<Cmr, EarlyEndOfStreamError> {
217        let mut ret = [0; 32];
218        for byte in &mut ret {
219            *byte = self.read_u8()?;
220        }
221        Ok(Cmr::from_byte_array(ret))
222    }
223
224    /// Reads a 512-bit fail-combinator entropy from the iterator.
225    pub fn read_fail_entropy(&mut self) -> Result<FailEntropy, EarlyEndOfStreamError> {
226        let mut ret = [0; 64];
227        for byte in &mut ret {
228            *byte = self.read_u8()?;
229        }
230        Ok(FailEntropy::from_byte_array(ret))
231    }
232
233    /// Decode a natural number from bits.
234    ///
235    /// If a bound is specified, then the decoding terminates before trying to
236    /// decode a larger number.
237    pub fn read_natural(&mut self, bound: Option<usize>) -> Result<usize, decode::Error> {
238        decode::decode_natural(self, bound)
239    }
240
241    /// Accessor for the number of bits which have been read,
242    /// in total, from this iterator
243    pub fn n_total_read(&self) -> usize {
244        self.total_read
245    }
246
247    /// Consumes the bit iterator, checking that there are no remaining
248    /// bytes and that any unread bits are zero.
249    pub fn close(mut self) -> Result<(), CloseError> {
250        if let Some(first_byte) = self.iter.next() {
251            return Err(CloseError::TrailingBytes { first_byte });
252        }
253
254        debug_assert!(self.read_bits >= 1);
255        debug_assert!(self.read_bits <= 8);
256        let n_bits = 8 - self.read_bits;
257        let masked_padding = self.cached_byte & ((1u8 << n_bits) - 1);
258        if masked_padding != 0 {
259            Err(CloseError::IllegalPadding {
260                masked_padding,
261                n_bits,
262            })
263        } else {
264            Ok(())
265        }
266    }
267}
268
/// Helpers for turning an iterator of Booleans back into packed bytes.
pub trait BitCollector: Sized {
    /// Collects all bits into a byte vector, returned together with the
    /// number of bits that were collected.
    fn collect_bits(self) -> (Vec<u8>, usize);

    /// Collects all bits into a byte vector, insisting that they form
    /// whole bytes.
    ///
    /// # Errors
    ///
    /// Fails if the number of collected bits is not a multiple of 8.
    fn try_collect_bytes(self) -> Result<Vec<u8>, &'static str> {
        match self.collect_bits() {
            (bytes, bit_length) if bit_length % 8 == 0 => Ok(bytes),
            _ => Err("Number of collected bits is not divisible by 8"),
        }
    }
}
286
287impl<I: Iterator<Item = bool>> BitCollector for I {
288    fn collect_bits(self) -> (Vec<u8>, usize) {
289        let mut bytes = vec![];
290        let mut unfinished_byte = Vec::with_capacity(8);
291
292        for bit in self {
293            unfinished_byte.push(bit);
294
295            if unfinished_byte.len() == 8 {
296                bytes.push(
297                    unfinished_byte
298                        .iter()
299                        .fold(0, |acc, &b| acc * 2 + u8::from(b)),
300                );
301                unfinished_byte.clear();
302            }
303        }
304
305        let bit_length = bytes.len() * 8 + unfinished_byte.len();
306
307        if !unfinished_byte.is_empty() {
308            unfinished_byte.resize(8, false);
309            bytes.push(
310                unfinished_byte
311                    .iter()
312                    .fold(0, |acc, &b| acc * 2 + u8::from(b)),
313            );
314        }
315
316        (bytes, bit_length)
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    /// Every reader fails on an empty stream and nothing is counted as read.
    #[test]
    fn empty_iter() {
        let mut iter = BitIter::from([].iter().cloned());
        assert!(iter.next().is_none());
        assert_eq!(iter.read_bit(), Err(EarlyEndOfStreamError));
        assert_eq!(iter.read_u2(), Err(EarlyEndOfStreamError));
        assert_eq!(iter.read_u8(), Err(EarlyEndOfStreamError));
        assert_eq!(iter.read_cmr(), Err(EarlyEndOfStreamError));
        assert_eq!(iter.n_total_read(), 0);
    }

    /// A failed byte read in the middle of the last byte consumes nothing.
    #[test]
    fn one_bit_iter() {
        // 0x80 = 1000_0000
        let mut iter = BitIter::from([0x80].iter().cloned());
        assert_eq!(iter.read_bit(), Ok(true));
        assert_eq!(iter.read_bit(), Ok(false));
        // Six bits are left, but there is no second byte to complete a u8.
        assert_eq!(iter.read_u8(), Err(EarlyEndOfStreamError));
        assert_eq!(iter.n_total_read(), 2);
    }

    /// Bits come out most significant bit first.
    #[test]
    fn bit_by_bit() {
        // 0x0f = 0000_1111, 0xaa = 1010_1010
        let mut iter = BitIter::from([0x0f, 0xaa].iter().cloned());
        for _ in 0..4 {
            assert_eq!(iter.next(), Some(false));
        }
        for _ in 0..4 {
            assert_eq!(iter.next(), Some(true));
        }
        for _ in 0..4 {
            assert_eq!(iter.next(), Some(true));
            assert_eq!(iter.next(), Some(false));
        }
        assert_eq!(iter.next(), None);
    }

    /// Byte-aligned `read_u8` returns the input bytes unchanged.
    #[test]
    fn byte_by_byte() {
        let mut iter = BitIter::from([0x0f, 0xaa].iter().cloned());
        assert_eq!(iter.read_u8(), Ok(0x0f));
        assert_eq!(iter.read_u8(), Ok(0xaa));
        assert_eq!(iter.next(), None);
    }

    /// `read_u2` works across a byte boundary.
    #[test]
    fn regression_1() {
        // 0x34 0x90 = 0011_0100 1001_0000
        let mut iter = BitIter::from([0x34, 0x90].iter().cloned());
        assert_eq!(iter.read_u2(), Ok(u2::_0)); // 00
        assert_eq!(iter.read_u2(), Ok(u2::_3)); // 11
        assert_eq!(iter.next(), Some(false)); // 0
        assert_eq!(iter.read_u2(), Ok(u2::_2)); // 10
        assert_eq!(iter.read_u2(), Ok(u2::_1)); // 01 (second bit is from the next byte)
        assert_eq!(iter.n_total_read(), 9);
    }

    /// Windows over a byte slice, with aligned and unaligned starts.
    #[test]
    fn byte_slice_window() {
        // 0x12 0x23 0x34 = 0001_0010 0010_0011 0011_0100
        let data = [0x12, 0x23, 0x34];

        // Whole slice.
        let mut full = BitIter::byte_slice_window(&data, 0, 24);
        assert_eq!(full.read_u8(), Ok(0x12));
        assert_eq!(full.n_total_read(), 8);
        assert_eq!(full.read_u8(), Ok(0x23));
        assert_eq!(full.n_total_read(), 16);
        assert_eq!(full.read_u8(), Ok(0x34));
        assert_eq!(full.n_total_read(), 24);
        assert_eq!(full.read_u8(), Err(EarlyEndOfStreamError));

        // Byte-aligned window covering only the middle byte.
        let mut mid = BitIter::byte_slice_window(&data, 8, 16);
        assert_eq!(mid.read_u8(), Ok(0x23));
        assert_eq!(mid.read_u8(), Err(EarlyEndOfStreamError));

        // Start shifted by half a byte.
        let mut offs = BitIter::byte_slice_window(&data, 4, 20);
        assert_eq!(offs.read_u8(), Ok(0x22));
        assert_eq!(offs.read_u8(), Ok(0x33));
        assert_eq!(offs.read_u8(), Err(EarlyEndOfStreamError));

        // Start shifted by a single bit.
        let mut shift1 = BitIter::byte_slice_window(&data, 1, 24);
        assert_eq!(shift1.read_u8(), Ok(0x24));
        assert_eq!(shift1.read_u8(), Ok(0x46));
        assert_eq!(shift1.read_u8(), Err(EarlyEndOfStreamError));

        // Start at the last bit of the first byte.
        let mut shift7 = BitIter::byte_slice_window(&data, 7, 24);
        assert_eq!(shift7.read_u8(), Ok(0x11));
        assert_eq!(shift7.read_u8(), Ok(0x9a));
        assert_eq!(shift7.read_u8(), Err(EarlyEndOfStreamError));
    }
}