bit_cursor/
bit_cursor.rs

1use std::{
2    fmt::LowerHex,
3    io::{Read, Seek, SeekFrom, Write},
4};
5
6use bitvec::{order::Msb0, slice::BitSlice, vec::BitVec, view::BitView};
7
8use crate::{bit_read::BitRead, bit_write::BitWrite};
9
/// A cursor over a bit-addressable buffer.
///
/// Pairs an inner buffer with a position that is counted in _bits_ from the start of the buffer.
#[derive(Debug, Default, Eq, PartialEq)]
pub struct BitCursor<T> {
    /// The wrapped buffer.
    inner: T,
    /// Current offset of the cursor, in bits.
    pos: u64,
}

impl<T> BitCursor<T> {
    /// Wraps `inner` in a new cursor.
    ///
    /// The cursor always starts at position `0`, even when `inner` already holds data.
    pub fn new(inner: T) -> Self {
        Self { inner, pos: 0 }
    }

    /// Unwraps the cursor, handing back the buffer it was built around.
    pub fn into_inner(self) -> T {
        let Self { inner, .. } = self;
        inner
    }

    /// Reports where the cursor currently is, measured in _bits_ from the start.
    pub fn position(&self) -> u64 {
        self.pos
    }

    /// Moves the cursor to `pos`, measured in _bits_ from the start.
    pub fn set_position(&mut self, pos: u64) {
        self.pos = pos;
    }
}
39
40impl BitCursor<BitVec<u8, Msb0>> {
41    /// Create a BitCursor from a [`Vec<u8>`]
42    pub fn from_vec(data: Vec<u8>) -> Self {
43        Self {
44            inner: BitVec::from_vec(data),
45            pos: 0,
46        }
47    }
48
49    /// Get the data between the current cursor position and the end of the data as a [`BitSlice`]
50    pub fn remaining_slice(&self) -> &BitSlice<u8, Msb0> {
51        let len = self.pos.min(self.inner.capacity() as u64);
52        &self.inner.as_bitslice()[(len as usize)..]
53    }
54
55    /// Get the data between the current cursor position and the end of the data as a mutable [`BitSlice`]
56    pub fn remaining_slice_mut(&mut self) -> &mut BitSlice<u8, Msb0> {
57        let start = self.pos.min(self.inner.capacity() as u64);
58        &mut self.inner.as_mut_bitslice()[(start as usize)..]
59    }
60
61    /// Returns true if the remaining slice is empty
62    pub fn is_empty(&self) -> bool {
63        self.pos >= self.remaining_slice().len() as u64
64    }
65}
66
67impl BitCursor<&BitSlice<u8, Msb0>> {
68    /// Get the data between the current cursor position and the end of the data as a [`BitSlice`]
69    pub fn remaining_slice(&self) -> &BitSlice<u8, Msb0> {
70        let len = self.pos.min(self.inner.len() as u64);
71        &self.inner[(len as usize)..]
72    }
73
74    pub fn is_empty(&self) -> bool {
75        self.pos >= self.remaining_slice().len() as u64
76    }
77}
78
79impl BitCursor<&[u8]> {
80    pub fn remaining_slice(&self) -> &BitSlice<u8, Msb0> {
81        // Here we have to mulitply the slice length by 8, since it's in bytes
82        let len = self.pos.min((self.inner.len() * 8) as u64);
83        &self.inner.view_bits::<Msb0>()[(len as usize)..]
84    }
85}
86
87impl<T> Clone for BitCursor<T>
88where
89    T: Clone,
90{
91    fn clone(&self) -> Self {
92        BitCursor {
93            inner: self.inner.clone(),
94            pos: self.pos,
95        }
96    }
97}
98
99impl Seek for BitCursor<&BitSlice<u8, Msb0>> {
100    fn seek(&mut self, style: SeekFrom) -> std::io::Result<u64> {
101        let (base_pos, offset) = match style {
102            SeekFrom::Start(n) => {
103                self.pos = n;
104                return Ok(self.pos);
105            }
106            SeekFrom::End(n) => (self.inner.len() as u64, n),
107            SeekFrom::Current(n) => (self.pos, n),
108        };
109        match base_pos.checked_add_signed(offset) {
110            Some(n) => {
111                self.pos = n;
112                Ok(self.pos)
113            }
114            None => Err(std::io::Error::new(
115                std::io::ErrorKind::InvalidInput,
116                "invalid seek to a negative or overflowing position",
117            )),
118        }
119    }
120}
121
122impl Seek for BitCursor<BitVec<u8, Msb0>> {
123    fn seek(&mut self, style: SeekFrom) -> std::io::Result<u64> {
124        let (base_pos, offset) = match style {
125            SeekFrom::Start(n) => {
126                self.pos = n;
127                return Ok(self.pos);
128            }
129            SeekFrom::End(n) => (self.inner.len() as u64, n),
130            SeekFrom::Current(n) => (self.pos, n),
131        };
132        match base_pos.checked_add_signed(offset) {
133            Some(n) => {
134                self.pos = n;
135                Ok(self.pos)
136            }
137            None => Err(std::io::Error::new(
138                std::io::ErrorKind::InvalidInput,
139                "invalid seek to a negative or overflowing position",
140            )),
141        }
142    }
143}
144
145impl Seek for BitCursor<&[u8]> {
146    fn seek(&mut self, style: SeekFrom) -> std::io::Result<u64> {
147        let (base_pos, offset) = match style {
148            SeekFrom::Start(n) => {
149                self.pos = n;
150                return Ok(self.pos);
151            }
152            SeekFrom::End(n) => (self.inner.len() as u64, n),
153            SeekFrom::Current(n) => (self.pos, n),
154        };
155        match base_pos.checked_add_signed(offset) {
156            Some(n) => {
157                self.pos = n;
158                Ok(self.pos)
159            }
160            None => Err(std::io::Error::new(
161                std::io::ErrorKind::InvalidInput,
162                "invalid seek to a negative or overflowing position",
163            )),
164        }
165    }
166}
167
168impl Read for BitCursor<BitVec<u8, Msb0>> {
169    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
170        if self.pos % 8 != 0 {
171            return Err(std::io::Error::new(
172                std::io::ErrorKind::Other,
173                "Attempted byte-level read when not on byte boundary",
174            ));
175        }
176        match self.remaining_slice().read(buf) {
177            Ok(n) => {
178                self.pos += (n * 8) as u64;
179                Ok(n)
180            }
181            Err(e) => Err(e),
182        }
183    }
184}
185
186impl Read for BitCursor<&[u8]> {
187    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
188        if self.pos % 8 != 0 {
189            return Err(std::io::Error::new(
190                std::io::ErrorKind::Other,
191                "Attempted byte-level read when not on byte boundary",
192            ));
193        }
194        match self.remaining_slice().read(buf) {
195            Ok(n) => {
196                self.pos += (n * 8) as u64;
197                Ok(n)
198            }
199            Err(e) => Err(e),
200        }
201    }
202}
203
204impl Read for BitCursor<&BitSlice<u8, Msb0>> {
205    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
206        if self.pos % 8 != 0 {
207            return Err(std::io::Error::new(
208                std::io::ErrorKind::Other,
209                "Attempted byte-level read when not on byte boundary",
210            ));
211        }
212        match self.remaining_slice().read(buf) {
213            Ok(n) => {
214                self.pos += (n * 8) as u64;
215                Ok(n)
216            }
217            Err(e) => Err(e),
218        }
219    }
220}
221
222impl BitRead for BitCursor<BitVec<u8, Msb0>> {
223    fn read_bits(&mut self, buf: &mut [nsw_types::u1]) -> std::io::Result<usize> {
224        let n = BitRead::read_bits(&mut self.remaining_slice(), buf)?;
225        self.pos += n as u64;
226        Ok(n)
227    }
228}
229
230impl BitRead for BitCursor<&BitSlice<u8, Msb0>> {
231    fn read_bits(&mut self, buf: &mut [nsw_types::u1]) -> std::io::Result<usize> {
232        let n = BitRead::read_bits(&mut self.remaining_slice(), buf)?;
233        self.pos += n as u64;
234        Ok(n)
235    }
236}
237
238impl BitRead for BitCursor<&[u8]> {
239    fn read_bits(&mut self, buf: &mut [nsw_types::u1]) -> std::io::Result<usize> {
240        let n = BitRead::read_bits(&mut self.remaining_slice(), buf)?;
241        self.pos += n as u64;
242        Ok(n)
243    }
244}
245
246impl Write for BitCursor<BitVec<u8, Msb0>> {
247    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
248        if self.pos % 8 != 0 {
249            return Err(std::io::Error::new(
250                std::io::ErrorKind::Other,
251                "Attempted byte-level write when not on byte boundary",
252            ));
253        }
254        match self.remaining_slice_mut().write(buf) {
255            Ok(n) => {
256                self.pos += (n * 8) as u64;
257                Ok(n)
258            }
259            Err(e) => Err(e),
260        }
261    }
262
263    fn flush(&mut self) -> std::io::Result<()> {
264        Ok(())
265    }
266}
267
268impl Write for BitCursor<&mut BitSlice<u8, Msb0>> {
269    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
270        if self.pos % 8 != 0 {
271            return Err(std::io::Error::new(
272                std::io::ErrorKind::Other,
273                "Attempted byte-level write when not on byte boundary",
274            ));
275        }
276        match self.inner.write(buf) {
277            Ok(n) => {
278                self.pos += (n * 8) as u64;
279                Ok(n)
280            }
281            Err(e) => Err(e),
282        }
283    }
284
285    fn flush(&mut self) -> std::io::Result<()> {
286        Ok(())
287    }
288}
289
290impl BitWrite for BitCursor<BitVec<u8, Msb0>> {
291    fn write_bits(&mut self, buf: &[nsw_types::u1]) -> std::io::Result<usize> {
292        let n = BitWrite::write_bits(&mut self.remaining_slice_mut(), buf)?;
293        self.pos += n as u64;
294        Ok(n)
295    }
296}
297
298impl BitWrite for BitCursor<&mut BitSlice<u8, Msb0>> {
299    fn write_bits(&mut self, buf: &[nsw_types::u1]) -> std::io::Result<usize> {
300        let n = BitWrite::write_bits(&mut self.inner, buf)?;
301        self.pos += n as u64;
302        Ok(n)
303    }
304}
305
306impl<T> LowerHex for BitCursor<T>
307where
308    T: LowerHex,
309{
310    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
311        write!(f, "buf: {:x}, pos: {}", self.inner, self.pos)
312    }
313}
314
#[cfg(test)]
mod test {
    use std::io::{Seek, SeekFrom};

    use bitvec::{bits, order::Msb0, vec::BitVec};
    use nsw_types::u1;

    use crate::{bit_read::BitRead, bit_read_exts::BitReadExts, sub_cursor::SubCursor};

    use super::BitCursor;

    /// Reads a two-byte buffer one nibble at a time, then checks that nothing is left.
    #[test]
    fn test_read() {
        let mut cursor = BitCursor::new(BitVec::<u8, Msb0>::from_vec(vec![0b11110000, 0b00001111]));
        let mut read_buf = [u1::new(0); 4];

        // The data is the nibble sequence 1111, 0000, 0000, 1111.
        for expected_bit in [1, 0, 0, 1] {
            assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 4);
            assert_eq!(read_buf, [u1::new(expected_bit); 4]);
        }

        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 0);
    }

    /// Exercises seeking relative to the end, the current position, and the start.
    #[test]
    fn test_seek() {
        let mut cursor = BitCursor::new(BitVec::<u8, Msb0>::from_vec(vec![0b11001100, 0b00110011]));
        let mut read_buf = [u1::new(0); 2];

        // Two bits before the end: the final `11`, after which the data is exhausted.
        cursor.seek(SeekFrom::End(-2)).expect("valid seek");
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 2);
        assert_eq!(read_buf, [u1::new(1); 2]);
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 0);

        // Step back to four bits before the end.
        cursor.seek(SeekFrom::Current(-4)).expect("valid seek");
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 2);
        assert_eq!(read_buf, [u1::new(0); 2]);

        // Absolute seek to the fifth bit of the first byte.
        cursor.seek(SeekFrom::Start(4)).expect("valid seek");
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 2);
        assert_eq!(read_buf, [u1::new(1); 2]);
    }

    /// Byte-level reads through `std::io::Read` consume the data two bytes at a time.
    #[test]
    fn test_read_bytes() {
        let mut cursor = BitCursor::new(BitVec::<u8, Msb0>::from_vec(vec![1, 2, 3, 4]));
        let mut buf = [0u8; 2];

        for expected in [[1u8, 2], [3, 4]] {
            std::io::Read::read(&mut cursor, &mut buf).expect("valid read");
            assert_eq!(buf, expected);
        }
    }

    /// A sub-cursor taken after reading one byte covers the 24 bits that follow.
    #[test]
    fn test_sub_cursor_vec() {
        let mut cursor = BitCursor::new(BitVec::<u8, Msb0>::from_vec(vec![1, 2, 3, 4]));

        let _first_byte = cursor.read_u8().unwrap();
        let mut sub_cursor = cursor.sub_cursor(0..24);

        assert_eq!(sub_cursor.remaining_slice().len(), 24);
        assert_eq!(sub_cursor.read_u8().unwrap(), 2);
    }

    /// After reading four bits from a byte slice, the remaining slice holds the other twelve.
    #[test]
    fn test_remaining_slice_u8() {
        let data: Vec<u8> = vec![0b00001111, 0b10101010];
        let mut cursor = BitCursor::new(&data[..]);

        cursor.read_u4().unwrap();

        assert_eq!(
            cursor.remaining_slice(),
            bits![u8, Msb0; 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0]
        );
    }
}
404}