bits_io/
bit_cursor.rs

1use std::{
2    fmt::LowerHex,
3    io::{Read, Seek, SeekFrom, Write},
4};
5
6use crate::{
7    bit_read::BitRead,
8    bit_seek::BitSeek,
9    bit_write::BitWrite,
10    borrow_bits::{BorrowBits, BorrowBitsMut},
11    prelude::*,
12};
13
/// A cursor over a bit-addressable buffer whose position is measured in _bits_ rather than
/// bytes.
///
/// `inner` is the wrapped buffer; the trait impls for this type view it as bits through
/// [`BorrowBits`] / [`BorrowBitsMut`]. `pos` is the bit offset at which the next read or write
/// starts.
#[derive(Debug, Default, Eq, PartialEq)]
pub struct BitCursor<T> {
    /// The wrapped buffer.
    inner: T,
    /// Current position, in bits from the start of `inner`. Nothing bounds this to the buffer
    /// length: `set_position` and `SeekFrom::Start` store the value as given.
    pos: u64,
}
19
20impl<T> BitCursor<T> {
21    /// Creates a new cursor wrapping the provided buffer.
22    ///
23    /// Cursor initial position is `0` even if the given buffer is not empty.
24    pub fn new(inner: T) -> BitCursor<T> {
25        BitCursor { inner, pos: 0 }
26    }
27
28    /// Gets a mutable reference to the inner value
29    pub fn get_mut(&mut self) -> &mut T {
30        &mut self.inner
31    }
32
33    /// Gets a reference to the inner value
34    pub fn get_ref(&self) -> &T {
35        &self.inner
36    }
37
38    /// Consumes the cursor, returning the inner value.
39    pub fn into_inner(self) -> T {
40        self.inner
41    }
42
43    /// Returns the position (in _bits_ since the start) of this cursor.
44    pub fn position(&self) -> u64 {
45        self.pos
46    }
47
48    /// Sets the position of this cursor (in _bits_ since the start)
49    pub fn set_position(&mut self, pos: u64) {
50        self.pos = pos;
51    }
52}
53
54impl<T> BitCursor<T>
55where
56    T: BorrowBits,
57{
58    pub fn split(&self) -> (&BitSlice, &BitSlice) {
59        let bits = self.inner.borrow_bits();
60        bits.split_at(self.pos as usize)
61    }
62}
63
64impl<T> BitCursor<T>
65where
66    T: BorrowBitsMut,
67{
68    pub fn split_mut(&mut self) -> (&mut BitSlice<impl BitStore>, &mut BitSlice<impl BitStore>) {
69        let bits = self.inner.borrow_bits_mut();
70        let (left, right) = bits.split_at_mut(self.pos as usize);
71        (left, right)
72    }
73}
74
75impl<T> Clone for BitCursor<T>
76where
77    T: Clone,
78{
79    fn clone(&self) -> Self {
80        BitCursor {
81            inner: self.inner.clone(),
82            pos: self.pos,
83        }
84    }
85}
86
87impl<T> BitSeek for BitCursor<T>
88where
89    T: BorrowBits,
90{
91    fn bit_seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
92        let (base_pos, offset) = match pos {
93            SeekFrom::Start(n) => {
94                self.pos = n;
95                return Ok(n);
96            }
97            SeekFrom::End(n) => (self.inner.borrow_bits().len() as u64, n),
98            SeekFrom::Current(n) => (self.pos, n),
99        };
100        match base_pos.checked_add_signed(offset) {
101            Some(n) => {
102                self.pos = n;
103                Ok(self.pos)
104            }
105            None => Err(std::io::Error::new(
106                std::io::ErrorKind::InvalidInput,
107                "invalid seek to a negative or overlfowing position",
108            )),
109        }
110    }
111}
112
113impl<T> Seek for BitCursor<T>
114where
115    T: BorrowBits,
116{
117    fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
118        match pos {
119            SeekFrom::Start(n) => self.bit_seek(SeekFrom::Start(n * 8)),
120            SeekFrom::End(n) => self.bit_seek(SeekFrom::End(n * 8)),
121            SeekFrom::Current(n) => self.bit_seek(SeekFrom::Current(n * 8)),
122        }
123    }
124}
125
126impl<T> Read for BitCursor<T>
127where
128    T: BorrowBits,
129{
130    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
131        let bits = self.inner.borrow_bits();
132        let remaining = &bits[self.pos as usize..];
133        let mut bytes_read = 0;
134
135        for (i, chunk) in remaining.chunks(8).take(buf.len()).enumerate() {
136            let mut byte = 0u8;
137            for (j, bit) in chunk.iter().enumerate() {
138                if *bit {
139                    byte |= 1 << (7 - j);
140                }
141            }
142            buf[i] = byte;
143            bytes_read += 1;
144        }
145
146        self.pos += (bytes_read * 8) as u64;
147        Ok(bytes_read)
148    }
149}
150
151impl<T> BitRead for BitCursor<T>
152where
153    T: BorrowBits,
154{
155    fn read_bits(&mut self, dest: &mut BitSlice) -> std::io::Result<usize> {
156        let n = BitRead::read_bits(&mut BitCursor::split(self).1, dest)?;
157        self.pos += n as u64;
158        Ok(n)
159    }
160}
161
162impl<T> Write for BitCursor<T>
163where
164    T: BorrowBitsMut,
165{
166    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
167        let n = Write::write(&mut BitCursor::split_mut(self).1, buf)?;
168        self.pos += (n * 8) as u64;
169        Ok(n)
170    }
171
172    fn flush(&mut self) -> std::io::Result<()> {
173        Ok(())
174    }
175}
176
177impl<T> BitWrite for BitCursor<T>
178where
179    T: BorrowBitsMut,
180    BitCursor<T>: std::io::Write,
181{
182    fn write_bits<O: BitStore>(&mut self, source: &BitSlice<O>) -> std::io::Result<usize> {
183        let n = BitWrite::write_bits(&mut BitCursor::split_mut(self).1, source)?;
184        self.pos += n as u64;
185        Ok(n)
186    }
187}
188
189impl<T> LowerHex for BitCursor<T>
190where
191    T: LowerHex,
192{
193    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194        write!(f, "buf: {:x}, pos: {}", self.inner, self.pos)
195    }
196}
197
#[cfg(test)]
mod test {
    use std::fmt::Debug;
    use std::io::{Seek, SeekFrom};

    use crate::prelude::*;
    use bitvec::{order::Msb0, view::BitView};
    use nsw_types::*;

    /// Reads as many bits as `expected` contains from a fresh cursor over `buf`. Asserts that
    /// the full count was read and that the bits equal `expected` viewed MSB-first.
    fn test_read_bits_helper<T: BorrowBits>(buf: T, expected: &[u8]) {
        let expected_bits = expected.view_bits::<Msb0>();
        let mut cursor = BitCursor::new(buf);
        let mut read_buf = bitvec![0; expected_bits.len()];
        assert_eq!(
            cursor.read_bits(read_buf.as_mut_bitslice()).unwrap(),
            expected_bits.len()
        );
        assert_eq!(read_buf, expected_bits);
    }

    #[test]
    fn test_read_bits() {
        let data = [0b11110000, 0b00001111];

        let vec = Vec::from(data);
        test_read_bits_helper(vec, &data);

        let bitvec = BitVec::from_slice(&data);
        test_read_bits_helper(bitvec, &data);

        let bitslice: &BitSlice = data.view_bits();
        test_read_bits_helper(bitslice, &data);

        let u8_slice = &data[..];
        test_read_bits_helper(u8_slice, &data);
    }

    #[test]
    fn test_read_bytes() {
        let data = BitVec::from_vec(vec![1, 2, 3, 4]);
        let mut cursor = BitCursor::new(data);

        let mut buf = [0u8; 2];
        std::io::Read::read(&mut cursor, &mut buf).expect("valid read");
        assert_eq!(buf, [1, 2]);
        std::io::Read::read(&mut cursor, &mut buf).expect("valid read");
        assert_eq!(buf, [3, 4]);
    }

    #[test]
    fn test_bit_seek() {
        let data = BitVec::from_vec(vec![0b11001100, 0b00110011]);
        let mut cursor = BitCursor::new(data);

        let mut read_buf = bitvec![0; 4];

        cursor.bit_seek(SeekFrom::End(-2)).expect("valid seek");
        // Should now be reading the last 2 bits
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 2);
        assert_eq!(read_buf, bits![1, 1, 0, 0]);
        // We already read to the end
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 0);

        // The read after the seek brought the cursor back to the end.  Now jump back 6 bits.
        cursor.bit_seek(SeekFrom::Current(-6)).expect("valid seek");
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 4);
        assert_eq!(read_buf, bits![1, 1, 0, 0]);

        cursor.bit_seek(SeekFrom::Start(4)).expect("valid seek");
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 4);
        assert_eq!(read_buf, bits![1, 1, 0, 0]);
    }

    #[test]
    fn test_seek() {
        let data = BitVec::from_vec(vec![0b11001100, 0b00110011]);
        let mut cursor = BitCursor::new(data);

        let mut read_buf = bitvec![0; 2];
        cursor.seek(SeekFrom::End(-1)).unwrap();
        // Should now be reading the last byte
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 2);
        assert_eq!(read_buf, bits![0, 0]);
        // Go back one byte
        cursor.seek(SeekFrom::Current(-1)).unwrap();
        // We should now be in bit position 2
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 2);
        assert_eq!(read_buf, bits![0, 0]);
    }

    /// Writes the 16-bit pattern `1100_1100_1100_1100` into `buf` through a cursor, using
    /// fields of assorted widths, and returns the buffer.
    fn test_write_bits_helper<T: BorrowBitsMut>(buf: T) -> T {
        let mut cursor = BitCursor::new(buf);
        cursor.write_u4(u4::new(0b1100)).unwrap();
        cursor.write_u2(u2::new(0b11)).unwrap();
        cursor.write_u2(u2::new(0b00)).unwrap();
        cursor.write_u3(u3::new(0b110)).unwrap();
        cursor.write_u5(u5::new(0b01100)).unwrap();
        cursor.into_inner()
    }

    #[test]
    fn test_write_bits_bitvec() {
        let buf = BitVec::from_vec(vec![0; 2]);

        assert_eq!(
            test_write_bits_helper(buf),
            BitVec::from_vec(vec![0b11001100, 0b11001100])
        );
    }

    #[test]
    fn test_write_bits_vec() {
        let buf: Vec<u8> = vec![0, 0];

        assert_eq!(test_write_bits_helper(buf), [0b11001100, 0b11001100]);
    }

    #[test]
    fn test_write_bits_bit_slice() {
        let mut data = [0u8; 2];
        let buf: &mut BitSlice = data.view_bits_mut::<Msb0>();

        assert_eq!(
            test_write_bits_helper(buf),
            BitVec::from_vec(vec![0b11001100, 0b11001100]).as_bitslice()
        );
    }

    #[test]
    fn test_write_bits_u8_slice() {
        let mut buf = [0u8; 2];

        assert_eq!(
            test_write_bits_helper(&mut buf[..]),
            [0b11001100, 0b11001100]
        );
    }

    /// Moves a cursor over `buf` forward 4 bits and checks that `split` divides the bits of
    /// `expected` at that point.
    fn test_split_helper<T: BorrowBits>(buf: T, expected: &[u8]) {
        let expected_bits = expected.view_bits::<Msb0>();
        let mut cursor = BitCursor::new(buf);
        cursor.bit_seek(SeekFrom::Current(4)).unwrap();
        let (before, after) = cursor.split();

        assert_eq!(before, expected_bits[..4]);
        assert_eq!(after, expected_bits[4..]);
    }

    #[test]
    fn test_split() {
        let data = [0b11110011, 0b10101010];

        let vec = Vec::from(data);
        test_split_helper(vec, &data);

        let bitvec = BitVec::from_slice(&data);
        test_split_helper(bitvec, &data);

        let bitslice: &BitSlice = data.view_bits();
        test_split_helper(bitslice, &data);

        let u8_slice = &data[..];
        test_split_helper(u8_slice, &data);
    }

    // Maybe a bit paranoid, but this creates cursors using different inner types, splits the data,
    // then makes sure that cursors can be created from each split and the data read correctly
    #[test]
    fn test_cursors_from_splits() {
        let data = [0b11110011, 0b10101010];

        let vec = Vec::from(data);
        let mut vec_cursor = BitCursor::new(vec);
        vec_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = vec_cursor.split();
        test_read_bits_helper(left, &data[..1]);
        test_read_bits_helper(right, &data[1..]);

        let bitvec = BitVec::from_slice(&data);
        let mut bitvec_cursor = BitCursor::new(bitvec);
        bitvec_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = bitvec_cursor.split();
        test_read_bits_helper(left, &data[..1]);
        test_read_bits_helper(right, &data[1..]);

        let bitslice: &BitSlice = data.view_bits();
        let mut bitslice_cursor = BitCursor::new(bitslice);
        bitslice_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = bitslice_cursor.split();
        test_read_bits_helper(left, &data[..1]);
        test_read_bits_helper(right, &data[1..]);

        let u8_slice = &data[..];
        let mut u8_cursor = BitCursor::new(u8_slice);
        u8_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = u8_cursor.split();
        test_read_bits_helper(left, &data[..1]);
        test_read_bits_helper(right, &data[1..]);
    }

    // Assumes the given buf is 4 bytes long
    fn test_split_mut_helper<T, U, F>(buf: T, create_expected: F)
    where
        T: BorrowBitsMut + PartialEq<U> + Debug,
        U: Debug,
        F: FnOnce(&[u8]) -> U,
    {
        let mut cursor = BitCursor::new(buf);
        cursor.seek(SeekFrom::Start(2)).unwrap();
        {
            let (mut before, mut after) = cursor.split_mut();

            before
                .write_u16::<NetworkOrder>(0b1111111100000000)
                .unwrap();
            after.write_u16::<NetworkOrder>(0b1100110000110011).unwrap();
        }

        let data = cursor.into_inner();
        let expected = create_expected(&[0b11111111, 0b00000000, 0b11001100, 0b00110011]);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_split_mut() {
        let data = [0u8; 4];

        let vec = Vec::from(data);
        test_split_mut_helper(vec, |v| v.to_vec());

        let bitvec = BitVec::from_vec(vec![0u8; 4]);
        test_split_mut_helper(bitvec, |v| BitVec::from_vec(v.to_vec()));

        let mut data = [0u8; 4];
        let bitslice: &mut BitSlice = data.view_bits_mut();
        test_split_mut_helper(bitslice, |v| BitVec::from_vec(v.to_vec()));

        let mut data = [0u8; 4];
        let u8_slice = &mut data[..];
        test_split_mut_helper(u8_slice, |v| v.to_vec());
    }

    #[test]
    fn test_alignment_reads_writes() {
        for offset in 0..8 {
            let buf = vec![0u8; 4];
            let mut cursor = BitCursor::new(buf);
            cursor.set_position(offset);
            let value = 0xDEADu16;
            cursor.write_u16::<BigEndian>(value).unwrap();
            cursor.set_position(offset);
            let read_value = cursor.read_u16::<BigEndian>().unwrap();
            assert_eq!(value, read_value, "offset {offset}");
        }
    }
}