bits_io/
bit_cursor.rs

1use std::{
2    fmt::LowerHex,
3    io::{Read, Seek, SeekFrom, Write},
4};
5
6use crate::{
7    bit_read::BitRead,
8    bit_seek::BitSeek,
9    bit_write::BitWrite,
10    borrow_bits::{BorrowBits, BorrowBitsMut},
11    prelude::*,
12};
13
/// A cursor over a bit-addressable buffer.
///
/// Pairs a wrapped value `T` with a position that is measured in *bits* (not bytes) from the
/// start of the buffer. Reading, writing and seeking are provided by the trait impls below
/// according to which bounds `T` satisfies.
#[derive(Debug, Default, PartialEq, Eq)]
pub struct BitCursor<T> {
    /// The wrapped buffer.
    inner: T,
    /// Current position, in bits from the start of `inner`.
    pos: u64,
}
19
20impl<T> BitCursor<T> {
21    /// Creates a new cursor wrapping the provided buffer.
22    ///
23    /// Cursor initial position is `0` even if the given buffer is not empty.
24    pub fn new(inner: T) -> BitCursor<T> {
25        BitCursor { inner, pos: 0 }
26    }
27
28    /// Gets a mutable reference to the inner value
29    pub fn get_mut(&mut self) -> &mut T {
30        &mut self.inner
31    }
32
33    /// Gets a reference to the inner value
34    pub fn get_ref(&self) -> &T {
35        &self.inner
36    }
37
38    /// Consumes the cursor, returning the inner value.
39    pub fn into_inner(self) -> T {
40        self.inner
41    }
42
43    /// Returns the position (in _bits_ since the start) of this cursor.
44    pub fn position(&self) -> u64 {
45        self.pos
46    }
47
48    /// Sets the position of this cursor (in _bits_ since the start)
49    pub fn set_position(&mut self, pos: u64) {
50        self.pos = pos;
51    }
52}
53
54impl<T> BitCursor<T>
55where
56    T: BorrowBits,
57{
58    pub fn split(&self) -> (&BitSlice, &BitSlice) {
59        let bits = self.inner.borrow_bits();
60        bits.split_at(self.pos as usize)
61    }
62}
63
64impl<T> BitCursor<T>
65where
66    T: BorrowBitsMut,
67{
68    pub fn split_mut(&mut self) -> (&mut BitSlice<BitSafeU8>, &mut BitSlice<BitSafeU8>) {
69        let bits = self.inner.borrow_bits_mut();
70        let (left, right) = bits.split_at_mut(self.pos as usize);
71        (left, right)
72    }
73}
74
75impl<T> Clone for BitCursor<T>
76where
77    T: Clone,
78{
79    fn clone(&self) -> Self {
80        BitCursor {
81            inner: self.inner.clone(),
82            pos: self.pos,
83        }
84    }
85}
86
87impl<T> BitSeek for BitCursor<T>
88where
89    T: BorrowBits,
90{
91    fn bit_seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
92        let (base_pos, offset) = match pos {
93            SeekFrom::Start(n) => {
94                self.pos = n;
95                return Ok(n);
96            }
97            SeekFrom::End(n) => (self.inner.borrow_bits().len() as u64, n),
98            SeekFrom::Current(n) => (self.pos, n),
99        };
100        match base_pos.checked_add_signed(offset) {
101            Some(n) => {
102                self.pos = n;
103                Ok(self.pos)
104            }
105            None => Err(std::io::Error::new(
106                std::io::ErrorKind::InvalidInput,
107                "invalid seek to a negative or overlfowing position",
108            )),
109        }
110    }
111}
112
113impl<T> Seek for BitCursor<T>
114where
115    T: BorrowBits,
116{
117    fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
118        match pos {
119            SeekFrom::Start(n) => self.bit_seek(SeekFrom::Start(n * 8)),
120            SeekFrom::End(n) => self.bit_seek(SeekFrom::End(n * 8)),
121            SeekFrom::Current(n) => self.bit_seek(SeekFrom::Current(n * 8)),
122        }
123    }
124}
125
126impl<T> Read for BitCursor<T>
127where
128    T: BorrowBits,
129{
130    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
131        let bits = self.inner.borrow_bits();
132        let remaining = &bits[self.pos as usize..];
133        let mut bytes_read = 0;
134
135        for (i, chunk) in remaining.chunks(8).take(buf.len()).enumerate() {
136            let mut byte = 0u8;
137            for (j, bit) in chunk.iter().enumerate() {
138                if *bit {
139                    byte |= 1 << (7 - j);
140                }
141            }
142            buf[i] = byte;
143            bytes_read += 1;
144        }
145
146        self.pos += (bytes_read * 8) as u64;
147        Ok(bytes_read)
148    }
149}
150
151impl<T> BitRead for BitCursor<T>
152where
153    T: BorrowBits,
154{
155    fn read_bits(&mut self, dest: &mut BitSlice) -> std::io::Result<usize> {
156        let n = BitRead::read_bits(&mut BitCursor::split(self).1, dest)?;
157        self.pos += n as u64;
158        Ok(n)
159    }
160}
161
162impl<T> Write for BitCursor<T>
163where
164    T: BorrowBitsMut,
165{
166    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
167        let n = Write::write(&mut BitCursor::split_mut(self).1, buf)?;
168        self.pos += (n * 8) as u64;
169        Ok(n)
170    }
171
172    fn flush(&mut self) -> std::io::Result<()> {
173        Ok(())
174    }
175}
176
177impl<T> BitWrite for BitCursor<T>
178where
179    T: BorrowBitsMut,
180    BitCursor<T>: std::io::Write,
181{
182    fn write_bits<O: BitStore>(&mut self, source: &BitSlice<O>) -> std::io::Result<usize> {
183        let n = BitWrite::write_bits(&mut BitCursor::split_mut(self).1, source)?;
184        self.pos += n as u64;
185        Ok(n)
186    }
187}
188
189impl<T> LowerHex for BitCursor<T>
190where
191    T: LowerHex,
192{
193    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194        write!(f, "buf: {:x}, pos: {}", self.inner, self.pos)
195    }
196}
197
#[cfg(test)]
mod test {
    use std::fmt::Debug;
    use std::io::{Seek, SeekFrom};

    use crate::prelude::*;
    use bitvec::bits;
    use bitvec::bitvec;
    use bitvec::view::BitView;
    use nsw_types::*;

    use crate::bit_read::BitRead;
    use crate::bit_read_exts::BitReadExts;
    use crate::bit_seek::BitSeek;
    use crate::bit_write_exts::BitWriteExts;
    use crate::borrow_bits::{BorrowBits, BorrowBitsMut};
    use crate::byte_order::NetworkOrder;

    use super::BitCursor;

    // Builds a fresh cursor over `buf` and reads exactly `expected.len() * 8` bits from it.
    // Asserts that the full count was read and that the bits equal `expected` viewed as Msb0.
    fn test_read_bits_hepler<T: BorrowBits>(buf: T, expected: &[u8]) {
        let expected_bits = expected.view_bits::<Msb0>();
        let mut cursor = BitCursor::new(buf);
        let mut read_buf = bitvec![u8, Msb0; 0; expected_bits.len()];
        assert_eq!(
            cursor.read_bits(&mut read_buf).unwrap(),
            expected_bits.len()
        );
        assert_eq!(read_buf, expected_bits);
    }

    // Runs the read helper over each supported inner type: Vec<u8>, BitVec, &BitSlice, &[u8].
    #[test]
    fn test_read_bits() {
        let data = [0b11110000, 0b00001111];

        let vec = Vec::from(data);
        test_read_bits_hepler(vec, &data);

        let bitvec = BitVec::from_slice(&data);
        test_read_bits_hepler(bitvec, &data);

        let bitslice: &BitSlice = data.view_bits();
        test_read_bits_hepler(bitslice, &data);

        let u8_slice = &data[..];
        test_read_bits_hepler(u8_slice, &data);
    }

    // Byte-level `std::io::Read` from a byte-aligned position; successive reads continue
    // where the previous one stopped.
    #[test]
    fn test_read_bytes() {
        let data = BitVec::from_vec(vec![1, 2, 3, 4]);
        let mut cursor = BitCursor::new(data);

        let mut buf = [0u8; 2];
        std::io::Read::read(&mut cursor, &mut buf).expect("valid read");
        assert_eq!(buf, [1, 2]);
        std::io::Read::read(&mut cursor, &mut buf).expect("valid read");
        assert_eq!(buf, [3, 4]);
    }

    // Exercises all three `SeekFrom` variants of `bit_seek`, which count in bits.
    #[test]
    fn test_bit_seek() {
        let data = BitVec::from_vec(vec![0b11001100, 0b00110011]);
        let mut cursor = BitCursor::new(data);

        let mut read_buf = bitvec![u8, Msb0; 0; 4];

        cursor.bit_seek(SeekFrom::End(-2)).expect("valid seek");
        // Should now be reading the last 2 bits
        // (only 2 of the 4 buffer bits are filled; the other 2 keep their initial 0)
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 2);
        assert_eq!(read_buf, bits![u8, Msb0; 1, 1, 0, 0]);
        // We already read to the end
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 0);

        // The read after the seek brought the cursor back to the end.  Now jump back 6 bits.
        // (bit 16 - 6 = bit 10, i.e. bits 2..6 of the second byte)
        cursor.bit_seek(SeekFrom::Current(-6)).expect("valid seek");
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 4);
        assert_eq!(read_buf, bits![u8, Msb0; 1, 1, 0, 0]);

        // Absolute bit offset 4 is the low nibble of the first byte.
        cursor.bit_seek(SeekFrom::Start(4)).expect("valid seek");
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 4);
        assert_eq!(read_buf, bits![u8, Msb0; 1, 1, 0, 0]);
    }

    // `std::io::Seek` offsets are in bytes; they are scaled by 8 before `bit_seek`.
    #[test]
    fn test_seek() {
        let data = BitVec::from_vec(vec![0b11001100, 0b00110011]);
        let mut cursor = BitCursor::new(data);

        let mut read_buf = bitvec![u8, Msb0; 0; 2];
        cursor.seek(SeekFrom::End(-1)).unwrap();
        // Should now be reading the last byte
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 2);
        assert_eq!(read_buf, bits![u8, Msb0; 0, 0]);
        // Go back one byte
        cursor.seek(SeekFrom::Current(-1)).unwrap();
        // We should now be in bit position 2
        // (bit 10 after the previous read, minus 8)
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 2);
        assert_eq!(read_buf, bits![u8, Msb0; 0, 0]);
    }

    // Writes 4 + 2 + 2 + 3 + 5 = 16 bits that together spell 0b11001100_11001100, then
    // returns the inner buffer for comparison. Assumes `buf` is at least 2 bytes long.
    fn test_write_bits_helper<T: BorrowBitsMut>(buf: T) -> T {
        let mut cursor = BitCursor::new(buf);
        cursor.write_u4(u4::new(0b1100)).unwrap();
        cursor.write_u2(u2::new(0b11)).unwrap();
        cursor.write_u2(u2::new(0b00)).unwrap();
        cursor.write_u3(u3::new(0b110)).unwrap();
        cursor.write_u5(u5::new(0b01100)).unwrap();
        cursor.into_inner()
    }

    #[test]
    fn test_write_bits_bitvec() {
        let buf = BitVec::from_vec(vec![0; 2]);

        assert_eq!(
            test_write_bits_helper(buf),
            BitVec::from_vec(vec![0b11001100, 0b11001100])
        );
    }

    #[test]
    fn test_write_bits_vec() {
        let buf: Vec<u8> = vec![0, 0];

        assert_eq!(test_write_bits_helper(buf), [0b11001100, 0b11001100]);
    }

    #[test]
    fn test_write_bits_bit_slice() {
        let mut data = [0u8; 2];
        let buf: &mut BitSlice = data.view_bits_mut::<Msb0>();

        assert_eq!(
            test_write_bits_helper(buf),
            BitVec::from_vec(vec![0b11001100, 0b11001100]).as_bitslice()
        );
    }

    #[test]
    fn test_write_bits_u8_slice() {
        let mut buf = [0u8; 2];

        assert_eq!(
            test_write_bits_helper(&mut buf[..]),
            [0b11001100, 0b11001100]
        );
    }

    // Moves the cursor 4 bits in, splits, and checks both halves against `expected`.
    fn test_split_helper<T: BorrowBits>(buf: T, expected: &[u8]) {
        let expected_bits = expected.view_bits::<Msb0>();
        let mut cursor = BitCursor::new(buf);
        cursor.bit_seek(SeekFrom::Current(4)).unwrap();
        let (before, after) = cursor.split();

        assert_eq!(before, expected_bits[..4]);
        assert_eq!(after, expected_bits[4..]);
    }

    #[test]
    fn test_split() {
        let data = [0b11110011, 0b10101010];

        let vec = Vec::from(data);
        test_split_helper(vec, &data);

        let bitvec = BitVec::from_slice(&data);
        test_split_helper(bitvec, &data);

        let bitslice: &BitSlice = data.view_bits();
        test_split_helper(bitslice, &data);

        let u8_slice = &data[..];
        test_split_helper(u8_slice, &data);
    }

    // Maybe a bit paranoid, but this creates cursors using different inner types, splits the data,
    // then makes sure that cursors can be created from each split and the data read correctly
    // (`seek(SeekFrom::Start(1))` is a byte seek, so each split happens at bit 8)
    #[test]
    fn test_cursors_from_splits() {
        let data = [0b11110011, 0b10101010];

        let vec = Vec::from(data);
        let mut vec_cursor = BitCursor::new(vec);
        vec_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = vec_cursor.split();
        test_read_bits_hepler(left, &data[..1]);
        test_read_bits_hepler(right, &data[1..]);

        let bitvec = BitVec::from_slice(&data);
        let mut bitvec_cursor = BitCursor::new(bitvec);
        bitvec_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = bitvec_cursor.split();
        test_read_bits_hepler(left, &data[..1]);
        test_read_bits_hepler(right, &data[1..]);

        let bitslice: &BitSlice = data.view_bits();
        let mut bitslice_cursor = BitCursor::new(bitslice);
        bitslice_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = bitslice_cursor.split();
        test_read_bits_hepler(left, &data[..1]);
        test_read_bits_hepler(right, &data[1..]);

        let u8_slice = &data[..];
        let mut u8_cursor = BitCursor::new(u8_slice);
        u8_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = u8_cursor.split();
        test_read_bits_hepler(left, &data[..1]);
        test_read_bits_hepler(right, &data[1..]);
    }

    // Assumes the given buf is 4 bytes long
    // Seeks to byte 2 (bit 16), writes one u16 into each half returned by `split_mut`, then
    // compares the resulting inner value against `create_expected` applied to the expected bytes.
    fn test_split_mut_helper<T, U, F>(buf: T, create_expected: F)
    where
        T: BorrowBitsMut + PartialEq<U> + Debug,
        U: Debug,
        F: FnOnce(&[u8]) -> U,
    {
        let mut cursor = BitCursor::new(buf);
        cursor.seek(SeekFrom::Start(2)).unwrap();
        // Scope the split halves so their mutable borrow of `cursor` ends before `into_inner`.
        {
            let (mut before, mut after) = cursor.split_mut();

            before
                .write_u16::<NetworkOrder>(0b1111111100000000)
                .unwrap();
            after.write_u16::<NetworkOrder>(0b1100110000110011).unwrap();
        }

        let data = cursor.into_inner();
        let expected = create_expected(&[0b11111111, 0b00000000, 0b11001100, 0b00110011]);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_split_mut() {
        let data = [0u8; 4];

        let vec = Vec::from(data);
        test_split_mut_helper(vec, |v| v.to_vec());

        let bitvec = BitVec::from_vec(vec![0u8; 4]);
        test_split_mut_helper(bitvec, |v| BitVec::from_vec(v.to_vec()));

        let mut data = [0u8; 4];
        let bitslice: &mut BitSlice = data.view_bits_mut();
        test_split_mut_helper(bitslice, |v| BitVec::from_vec(v.to_vec()));

        let mut data = [0u8; 4];
        let u8_slice = &mut data[..];
        test_split_mut_helper(u8_slice, |v| v.to_vec());
    }

    // Round-trips a u16 written at every bit offset 0..8 to check that unaligned writes and
    // reads agree. The 4-byte buffer leaves room for offset 7 + 16 bits.
    #[test]
    fn test_alignment_reads_writes() {
        for offset in 0..8 {
            let buf = vec![0u8; 4];
            let mut cursor = BitCursor::new(buf);
            cursor.set_position(offset);
            let value = 0xDEADu16;
            cursor.write_u16::<BigEndian>(value).unwrap();
            cursor.set_position(offset);
            let read_value = cursor.read_u16::<BigEndian>().unwrap();
            assert_eq!(value, read_value, "offset {offset}");
        }
    }
}