bits_io/io/
bit_cursor.rs

1use std::{
2    fmt::LowerHex,
3    io::{Read, Seek, SeekFrom, Write},
4};
5
6use crate::prelude::*;
7
/// A cursor over a bit-addressable buffer whose position is counted in _bits_.
#[derive(Debug, Default, PartialEq, Eq)]
pub struct BitCursor<T> {
    /// The wrapped buffer.
    inner: T,
    /// Offset, in bits, from the start of `inner`.
    pos: u64,
}
13
14impl<T> BitCursor<T> {
15    /// Creates a new cursor wrapping the provided buffer.
16    ///
17    /// Cursor initial position is `0` even if the given buffer is not empty.
18    pub fn new(inner: T) -> BitCursor<T> {
19        BitCursor { inner, pos: 0 }
20    }
21
22    /// Gets a mutable reference to the inner value
23    pub fn get_mut(&mut self) -> &mut T {
24        &mut self.inner
25    }
26
27    /// Gets a reference to the inner value
28    pub fn get_ref(&self) -> &T {
29        &self.inner
30    }
31
32    /// Consumes the cursor, returning the inner value.
33    pub fn into_inner(self) -> T {
34        self.inner
35    }
36
37    /// Returns the position (in _bits_ since the start) of this cursor.
38    pub fn position(&self) -> u64 {
39        self.pos
40    }
41
42    /// Sets the position of this cursor (in _bits_ since the start)
43    pub fn set_position(&mut self, pos: u64) {
44        self.pos = pos;
45    }
46}
47
48impl<T> BitCursor<T>
49where
50    T: BorrowBits,
51{
52    /// Splits the underlying slice at the cursor position and returns each half.
53    pub fn split(&self) -> (&BitSlice, &BitSlice) {
54        let bits = self.inner.borrow_bits();
55        bits.split_at(self.pos as usize)
56    }
57}
58
59impl<T> BitCursor<T>
60where
61    T: BorrowBitsMut,
62{
63    /// Splits the underlying slice at the cursor position and returns each half mutably
64    /// TODO: should we be re-exporting BitSafeU8 in some other way?
65    pub fn split_mut(
66        &mut self,
67    ) -> (
68        &mut BitSlice<bitvec::access::BitSafeU8>,
69        &mut BitSlice<bitvec::access::BitSafeU8>,
70    ) {
71        let bits = self.inner.borrow_bits_mut();
72        let (left, right) = bits.split_at_mut(self.pos as usize);
73        (left, right)
74    }
75}
76
77impl<T> Clone for BitCursor<T>
78where
79    T: Clone,
80{
81    fn clone(&self) -> Self {
82        BitCursor {
83            inner: self.inner.clone(),
84            pos: self.pos,
85        }
86    }
87}
88
89impl<T> BitSeek for BitCursor<T>
90where
91    T: BorrowBits,
92{
93    fn bit_seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
94        let (base_pos, offset) = match pos {
95            SeekFrom::Start(n) => {
96                self.pos = n;
97                return Ok(n);
98            }
99            SeekFrom::End(n) => (self.inner.borrow_bits().len() as u64, n),
100            SeekFrom::Current(n) => (self.pos, n),
101        };
102        match base_pos.checked_add_signed(offset) {
103            Some(n) => {
104                self.pos = n;
105                Ok(self.pos)
106            }
107            None => Err(std::io::Error::new(
108                std::io::ErrorKind::InvalidInput,
109                "invalid seek to a negative or overlfowing position",
110            )),
111        }
112    }
113}
114
115impl<T> Seek for BitCursor<T>
116where
117    T: BorrowBits,
118{
119    fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
120        match pos {
121            SeekFrom::Start(n) => self.bit_seek(SeekFrom::Start(n * 8)),
122            SeekFrom::End(n) => self.bit_seek(SeekFrom::End(n * 8)),
123            SeekFrom::Current(n) => self.bit_seek(SeekFrom::Current(n * 8)),
124        }
125    }
126}
127
128impl<T> Read for BitCursor<T>
129where
130    T: BorrowBits,
131{
132    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
133        let bits = self.inner.borrow_bits();
134        let remaining = &bits[self.pos as usize..];
135        let mut bytes_read = 0;
136
137        for (i, chunk) in remaining.chunks(8).take(buf.len()).enumerate() {
138            let mut byte = 0u8;
139            for (j, bit) in chunk.iter().enumerate() {
140                if *bit {
141                    byte |= 1 << (7 - j);
142                }
143            }
144            buf[i] = byte;
145            bytes_read += 1;
146        }
147
148        self.pos += (bytes_read * 8) as u64;
149        Ok(bytes_read)
150    }
151}
152
153impl<T> BitRead for BitCursor<T>
154where
155    T: BorrowBits,
156{
157    fn read_bits<O: BitStore>(&mut self, dest: &mut BitSlice<O>) -> std::io::Result<usize> {
158        let n = BitRead::read_bits(&mut BitCursor::split(self).1, dest)?;
159        self.pos += n as u64;
160        Ok(n)
161    }
162}
163
164impl<T> Write for BitCursor<T>
165where
166    T: BorrowBitsMut,
167{
168    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
169        let n = Write::write(&mut BitCursor::split_mut(self).1, buf)?;
170        self.pos += (n * 8) as u64;
171        Ok(n)
172    }
173
174    fn flush(&mut self) -> std::io::Result<()> {
175        Ok(())
176    }
177}
178
179impl<T> BitWrite for BitCursor<T>
180where
181    T: BorrowBitsMut,
182    BitCursor<T>: std::io::Write,
183{
184    fn write_bits<O: BitStore>(&mut self, source: &BitSlice<O>) -> std::io::Result<usize> {
185        let n = BitWrite::write_bits(&mut BitCursor::split_mut(self).1, source)?;
186        self.pos += n as u64;
187        Ok(n)
188    }
189}
190
191impl<T> LowerHex for BitCursor<T>
192where
193    T: LowerHex,
194{
195    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196        write!(f, "buf: {:x}, pos: {}", self.inner, self.pos)
197    }
198}
199
#[cfg(test)]
mod test {
    use std::fmt::Debug;
    use std::io::{Seek, SeekFrom};

    use crate::prelude::*;
    use bitvec::{order::Msb0, view::BitView};

    /// Reads `expected.len() * 8` bits from a fresh cursor over `buf`. Asserts that the read
    /// reports that many bits and that they equal the MSB-first bits of `expected`.
    fn test_read_bits_helper<T: BorrowBits>(buf: T, expected: &[u8]) {
        let expected_bits = expected.view_bits::<Msb0>();
        let mut cursor = BitCursor::new(buf);
        let mut read_buf = bitvec![0; expected_bits.len()];
        assert_eq!(
            cursor.read_bits(read_buf.as_mut_bitslice()).unwrap(),
            expected_bits.len()
        );
        assert_eq!(read_buf, expected_bits);
    }

    #[test]
    fn test_read_bits() {
        let data = [0b11110000, 0b00001111];

        let vec = Vec::from(data);
        test_read_bits_helper(vec, &data);

        let bitvec = BitVec::from_slice(&data);
        test_read_bits_helper(bitvec, &data);

        let bitslice: &BitSlice = data.view_bits();
        test_read_bits_helper(bitslice, &data);

        let u8_slice = &data[..];
        test_read_bits_helper(u8_slice, &data);
    }

    #[test]
    fn test_read_bytes() {
        let data = BitVec::from_vec(vec![1, 2, 3, 4]);
        let mut cursor = BitCursor::new(data);

        let mut buf = [0u8; 2];
        std::io::Read::read(&mut cursor, &mut buf).expect("valid read");
        assert_eq!(buf, [1, 2]);
        std::io::Read::read(&mut cursor, &mut buf).expect("valid read");
        assert_eq!(buf, [3, 4]);
    }

    #[test]
    fn test_bit_seek() {
        let data = BitVec::from_vec(vec![0b11001100, 0b00110011]);
        let mut cursor = BitCursor::new(data);

        let mut read_buf = bitvec![0; 4];

        cursor.bit_seek(SeekFrom::End(-2)).expect("valid seek");
        // Should now be reading the last 2 bits
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 2);
        assert_eq!(read_buf, bits![1, 1, 0, 0]);
        // We already read to the end
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 0);

        // The read after the seek brought the cursor back to the end.  Now jump back 6 bits.
        cursor.bit_seek(SeekFrom::Current(-6)).expect("valid seek");
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 4);
        assert_eq!(read_buf, bits![1, 1, 0, 0]);

        cursor.bit_seek(SeekFrom::Start(4)).expect("valid seek");
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 4);
        assert_eq!(read_buf, bits![1, 1, 0, 0]);
    }

    #[test]
    fn test_seek() {
        let data = BitVec::from_vec(vec![0b11001100, 0b00110011]);
        let mut cursor = BitCursor::new(data);

        let mut read_buf = bitvec![0; 2];
        cursor.seek(SeekFrom::End(-1)).unwrap();
        // Should now be reading the last byte
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 2);
        assert_eq!(read_buf, bits![0, 0]);
        // Go back one byte
        cursor.seek(SeekFrom::Current(-1)).unwrap();
        // We should now be in bit position 2
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 2);
        assert_eq!(read_buf, bits![0, 0]);
    }

    /// Writes 16 known bits through a cursor over `buf`, then checks the buffer's contents.
    /// The final equality only holds if `buf` is exactly 16 bits long.
    fn test_write_bits_helper<T>(buf: T)
    where
        T: BorrowBitsMut + Debug,
        BitCursor<T>: std::io::Write,
    {
        let mut cursor = BitCursor::new(buf);
        let data = bits![1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0];
        assert_eq!(cursor.write_bits(data).unwrap(), 16);
        assert_eq!(cursor.into_inner().borrow_bits(), data);
    }

    #[test]
    fn test_write_bits_bitvec() {
        let buf = BitVec::from_vec(vec![0; 2]);
        test_write_bits_helper(buf);
    }

    #[test]
    fn test_write_bits_vec() {
        let buf: Vec<u8> = vec![0, 0];
        test_write_bits_helper(buf);
    }

    #[test]
    fn test_write_bits_bit_slice() {
        let mut buf = bitvec![0; 16];
        test_write_bits_helper(buf.as_mut_bitslice());
    }

    #[test]
    fn test_write_bits_u8_slice() {
        let mut buf = [0u8; 2];
        test_write_bits_helper(&mut buf[..]);
    }

    /// Seeks 4 bits into a cursor over `buf`. Checks that `split` yields the first 4 bits of
    /// `expected` and then the rest.
    fn test_split_helper<T: BorrowBits>(buf: T, expected: &[u8]) {
        let expected_bits = expected.view_bits::<Msb0>();
        let mut cursor = BitCursor::new(buf);
        cursor.bit_seek(SeekFrom::Current(4)).unwrap();
        let (before, after) = cursor.split();

        assert_eq!(before, expected_bits[..4]);
        assert_eq!(after, expected_bits[4..]);
    }

    #[test]
    fn test_split() {
        let data = [0b11110011, 0b10101010];

        let vec = Vec::from(data);
        test_split_helper(vec, &data);

        let bitvec = BitVec::from_slice(&data);
        test_split_helper(bitvec, &data);

        let bitslice: &BitSlice = data.view_bits();
        test_split_helper(bitslice, &data);

        let u8_slice = &data[..];
        test_split_helper(u8_slice, &data);
    }

    // Maybe a bit paranoid, but this creates cursors using different inner types, splits the data,
    // then makes sure that cursors can be created from each split and the data read correctly
    #[test]
    fn test_cursors_from_splits() {
        let data = [0b11110011, 0b10101010];

        let vec = Vec::from(data);
        let mut vec_cursor = BitCursor::new(vec);
        vec_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = vec_cursor.split();
        test_read_bits_helper(left, &data[..1]);
        test_read_bits_helper(right, &data[1..]);

        let bitvec = BitVec::from_slice(&data);
        let mut bitvec_cursor = BitCursor::new(bitvec);
        bitvec_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = bitvec_cursor.split();
        test_read_bits_helper(left, &data[..1]);
        test_read_bits_helper(right, &data[1..]);

        let bitslice: &BitSlice = data.view_bits();
        let mut bitslice_cursor = BitCursor::new(bitslice);
        bitslice_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = bitslice_cursor.split();
        test_read_bits_helper(left, &data[..1]);
        test_read_bits_helper(right, &data[1..]);

        let u8_slice = &data[..];
        let mut u8_cursor = BitCursor::new(u8_slice);
        u8_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = u8_cursor.split();
        test_read_bits_helper(left, &data[..1]);
        test_read_bits_helper(right, &data[1..]);
    }

    // Assumes the given buf is 4 bytes long. The cursor is moved to byte 2. Each half of
    // `split_mut` is then written independently, and the combined buffer is checked.
    fn test_split_mut_helper<T>(buf: T)
    where
        T: BorrowBitsMut + Debug,
        BitCursor<T>: std::io::Write,
    {
        let mut cursor = BitCursor::new(buf);
        cursor.seek(SeekFrom::Start(2)).unwrap();

        {
            let (mut left, mut right) = cursor.split_mut();

            left.write_bits(bits![1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0])
                .unwrap();
            right
                .write_bits(bits![0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1])
                .unwrap();
        }

        let data = cursor.into_inner();
        assert_eq!(
            data.borrow_bits(),
            bits![
                1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0,
                1, 1, 1, 1
            ]
        );
    }

    #[test]
    fn test_split_mut() {
        let data = [0u8; 4];

        let vec = Vec::from(data);
        test_split_mut_helper(vec);

        let bitvec = BitVec::from_vec(vec![0u8; 4]);
        test_split_mut_helper(bitvec);

        let mut data = [0u8; 4];
        let bitslice: &mut BitSlice = data.view_bits_mut();
        test_split_mut_helper(bitslice);

        let mut data = [0u8; 4];
        let u8_slice = &mut data[..];
        test_split_mut_helper(u8_slice);
    }

    // Round-trips 16 bits at every bit offset within the first byte to exercise unaligned
    // writes and reads.
    #[test]
    fn test_alignment_reads_writes() {
        for offset in 0..8 {
            let buf = vec![0u8; 4];
            let mut cursor = BitCursor::new(buf);
            cursor.set_position(offset);
            let value = BitVec::from_slice(&[0xDE, 0xAD]);

            cursor.write_bits(value.as_bitslice()).unwrap();

            cursor.set_position(offset);
            let mut read_buf = BitVec::with_capacity(16);
            read_buf.resize(16, false);
            cursor.read_bits(read_buf.as_mut_bitslice()).unwrap();
            assert_eq!(value, read_buf, "offset {offset}");
        }
    }

    #[test]
    fn test_position_clone_and_lower_hex() {
        let mut cursor = BitCursor::new(vec![0u8; 2]);
        assert_eq!(cursor.position(), 0);
        cursor.set_position(5);
        assert_eq!(cursor.position(), 5);

        let copy = cursor.clone();
        assert_eq!(copy, cursor);

        assert_eq!(format!("{:x}", BitCursor::new(0xABu8)), "buf: ab, pos: 0");
    }
}