1use std::{
2 fmt::LowerHex,
3 io::{Read, Seek, SeekFrom, Write},
4};
5
6use bitvec::{order::Msb0, slice::BitSlice, vec::BitVec, view::BitView};
7
8use crate::{bit_read::BitRead, bit_write::BitWrite};
9
/// A cursor that tracks a position, in bits, over some underlying storage `T`.
///
/// Byte-oriented (`Read`/`Write`) and bit-oriented (`BitRead`/`BitWrite`) access are
/// provided for specific storage types in the impls below.
#[derive(Debug, Default, PartialEq, Eq)]
pub struct BitCursor<T> {
    /// The wrapped storage.
    inner: T,
    /// Current offset into `inner`, counted in bits.
    pos: u64,
}
15
16impl<T> BitCursor<T> {
17 pub fn new(inner: T) -> BitCursor<T> {
21 BitCursor { inner, pos: 0 }
22 }
23
24 pub fn into_inner(self) -> T {
26 self.inner
27 }
28
29 pub fn position(&self) -> u64 {
31 self.pos
32 }
33
34 pub fn set_position(&mut self, pos: u64) {
36 self.pos = pos;
37 }
38}
39
40impl BitCursor<BitVec<u8, Msb0>> {
41 pub fn from_vec(data: Vec<u8>) -> Self {
43 Self {
44 inner: BitVec::from_vec(data),
45 pos: 0,
46 }
47 }
48
49 pub fn remaining_slice(&self) -> &BitSlice<u8, Msb0> {
51 let len = self.pos.min(self.inner.capacity() as u64);
52 &self.inner.as_bitslice()[(len as usize)..]
53 }
54
55 pub fn remaining_slice_mut(&mut self) -> &mut BitSlice<u8, Msb0> {
57 let start = self.pos.min(self.inner.capacity() as u64);
58 &mut self.inner.as_mut_bitslice()[(start as usize)..]
59 }
60
61 pub fn is_empty(&self) -> bool {
63 self.pos >= self.remaining_slice().len() as u64
64 }
65}
66
67impl BitCursor<&BitSlice<u8, Msb0>> {
68 pub fn remaining_slice(&self) -> &BitSlice<u8, Msb0> {
70 let len = self.pos.min(self.inner.len() as u64);
71 &self.inner[(len as usize)..]
72 }
73
74 pub fn is_empty(&self) -> bool {
75 self.pos >= self.remaining_slice().len() as u64
76 }
77}
78
79impl BitCursor<&[u8]> {
80 pub fn remaining_slice(&self) -> &BitSlice<u8, Msb0> {
81 let len = self.pos.min((self.inner.len() * 8) as u64);
83 &self.inner.view_bits::<Msb0>()[(len as usize)..]
84 }
85}
86
87impl<T> Clone for BitCursor<T>
88where
89 T: Clone,
90{
91 fn clone(&self) -> Self {
92 BitCursor {
93 inner: self.inner.clone(),
94 pos: self.pos,
95 }
96 }
97}
98
99impl Seek for BitCursor<&BitSlice<u8, Msb0>> {
100 fn seek(&mut self, style: SeekFrom) -> std::io::Result<u64> {
101 let (base_pos, offset) = match style {
102 SeekFrom::Start(n) => {
103 self.pos = n;
104 return Ok(self.pos);
105 }
106 SeekFrom::End(n) => (self.inner.len() as u64, n),
107 SeekFrom::Current(n) => (self.pos, n),
108 };
109 match base_pos.checked_add_signed(offset) {
110 Some(n) => {
111 self.pos = n;
112 Ok(self.pos)
113 }
114 None => Err(std::io::Error::new(
115 std::io::ErrorKind::InvalidInput,
116 "invalid seek to a negative or overflowing position",
117 )),
118 }
119 }
120}
121
122impl Seek for BitCursor<BitVec<u8, Msb0>> {
123 fn seek(&mut self, style: SeekFrom) -> std::io::Result<u64> {
124 let (base_pos, offset) = match style {
125 SeekFrom::Start(n) => {
126 self.pos = n;
127 return Ok(self.pos);
128 }
129 SeekFrom::End(n) => (self.inner.len() as u64, n),
130 SeekFrom::Current(n) => (self.pos, n),
131 };
132 match base_pos.checked_add_signed(offset) {
133 Some(n) => {
134 self.pos = n;
135 Ok(self.pos)
136 }
137 None => Err(std::io::Error::new(
138 std::io::ErrorKind::InvalidInput,
139 "invalid seek to a negative or overflowing position",
140 )),
141 }
142 }
143}
144
145impl Seek for BitCursor<&[u8]> {
146 fn seek(&mut self, style: SeekFrom) -> std::io::Result<u64> {
147 let (base_pos, offset) = match style {
148 SeekFrom::Start(n) => {
149 self.pos = n;
150 return Ok(self.pos);
151 }
152 SeekFrom::End(n) => (self.inner.len() as u64, n),
153 SeekFrom::Current(n) => (self.pos, n),
154 };
155 match base_pos.checked_add_signed(offset) {
156 Some(n) => {
157 self.pos = n;
158 Ok(self.pos)
159 }
160 None => Err(std::io::Error::new(
161 std::io::ErrorKind::InvalidInput,
162 "invalid seek to a negative or overflowing position",
163 )),
164 }
165 }
166}
167
168impl Read for BitCursor<BitVec<u8, Msb0>> {
169 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
170 if self.pos % 8 != 0 {
171 return Err(std::io::Error::new(
172 std::io::ErrorKind::Other,
173 "Attempted byte-level read when not on byte boundary",
174 ));
175 }
176 match self.remaining_slice().read(buf) {
177 Ok(n) => {
178 self.pos += (n * 8) as u64;
179 Ok(n)
180 }
181 Err(e) => Err(e),
182 }
183 }
184}
185
186impl Read for BitCursor<&[u8]> {
187 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
188 if self.pos % 8 != 0 {
189 return Err(std::io::Error::new(
190 std::io::ErrorKind::Other,
191 "Attempted byte-level read when not on byte boundary",
192 ));
193 }
194 match self.remaining_slice().read(buf) {
195 Ok(n) => {
196 self.pos += (n * 8) as u64;
197 Ok(n)
198 }
199 Err(e) => Err(e),
200 }
201 }
202}
203
204impl Read for BitCursor<&BitSlice<u8, Msb0>> {
205 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
206 if self.pos % 8 != 0 {
207 return Err(std::io::Error::new(
208 std::io::ErrorKind::Other,
209 "Attempted byte-level read when not on byte boundary",
210 ));
211 }
212 match self.remaining_slice().read(buf) {
213 Ok(n) => {
214 self.pos += (n * 8) as u64;
215 Ok(n)
216 }
217 Err(e) => Err(e),
218 }
219 }
220}
221
222impl BitRead for BitCursor<BitVec<u8, Msb0>> {
223 fn read_bits(&mut self, buf: &mut [nsw_types::u1]) -> std::io::Result<usize> {
224 let n = BitRead::read_bits(&mut self.remaining_slice(), buf)?;
225 self.pos += n as u64;
226 Ok(n)
227 }
228}
229
230impl BitRead for BitCursor<&BitSlice<u8, Msb0>> {
231 fn read_bits(&mut self, buf: &mut [nsw_types::u1]) -> std::io::Result<usize> {
232 let n = BitRead::read_bits(&mut self.remaining_slice(), buf)?;
233 self.pos += n as u64;
234 Ok(n)
235 }
236}
237
238impl BitRead for BitCursor<&[u8]> {
239 fn read_bits(&mut self, buf: &mut [nsw_types::u1]) -> std::io::Result<usize> {
240 let n = BitRead::read_bits(&mut self.remaining_slice(), buf)?;
241 self.pos += n as u64;
242 Ok(n)
243 }
244}
245
246impl Write for BitCursor<BitVec<u8, Msb0>> {
247 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
248 if self.pos % 8 != 0 {
249 return Err(std::io::Error::new(
250 std::io::ErrorKind::Other,
251 "Attempted byte-level write when not on byte boundary",
252 ));
253 }
254 match self.remaining_slice_mut().write(buf) {
255 Ok(n) => {
256 self.pos += (n * 8) as u64;
257 Ok(n)
258 }
259 Err(e) => Err(e),
260 }
261 }
262
263 fn flush(&mut self) -> std::io::Result<()> {
264 Ok(())
265 }
266}
267
268impl Write for BitCursor<&mut BitSlice<u8, Msb0>> {
269 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
270 if self.pos % 8 != 0 {
271 return Err(std::io::Error::new(
272 std::io::ErrorKind::Other,
273 "Attempted byte-level write when not on byte boundary",
274 ));
275 }
276 match self.inner.write(buf) {
277 Ok(n) => {
278 self.pos += (n * 8) as u64;
279 Ok(n)
280 }
281 Err(e) => Err(e),
282 }
283 }
284
285 fn flush(&mut self) -> std::io::Result<()> {
286 Ok(())
287 }
288}
289
290impl BitWrite for BitCursor<BitVec<u8, Msb0>> {
291 fn write_bits(&mut self, buf: &[nsw_types::u1]) -> std::io::Result<usize> {
292 let n = BitWrite::write_bits(&mut self.remaining_slice_mut(), buf)?;
293 self.pos += n as u64;
294 Ok(n)
295 }
296}
297
298impl BitWrite for BitCursor<&mut BitSlice<u8, Msb0>> {
299 fn write_bits(&mut self, buf: &[nsw_types::u1]) -> std::io::Result<usize> {
300 let n = BitWrite::write_bits(&mut self.inner, buf)?;
301 self.pos += n as u64;
302 Ok(n)
303 }
304}
305
306impl<T> LowerHex for BitCursor<T>
307where
308 T: LowerHex,
309{
310 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
311 write!(f, "buf: {:x}, pos: {}", self.inner, self.pos)
312 }
313}
314
#[cfg(test)]
mod test {
    use std::io::{Seek, SeekFrom};

    use bitvec::{bits, order::Msb0, vec::BitVec};
    use nsw_types::u1;

    use crate::{bit_read::BitRead, bit_read_exts::BitReadExts, sub_cursor::SubCursor};

    use super::BitCursor;

    #[test]
    fn test_read() {
        let bytes = vec![0b1111_0000, 0b0000_1111];
        let mut cursor = BitCursor::new(BitVec::<u8, Msb0>::from_vec(bytes));

        let ones = [u1::new(1); 4];
        let zeros = [u1::new(0); 4];
        let mut nibble = [u1::new(0); 4];

        for expected in [ones, zeros, zeros, ones] {
            assert_eq!(cursor.read_bits(&mut nibble).unwrap(), 4);
            assert_eq!(nibble, expected);
        }

        // All 16 bits have been consumed, so another read yields nothing.
        assert_eq!(cursor.read_bits(&mut nibble).unwrap(), 0);
    }

    #[test]
    fn test_seek() {
        let bytes = vec![0b1100_1100, 0b0011_0011];
        let mut cursor = BitCursor::new(BitVec::<u8, Msb0>::from_vec(bytes));
        let mut pair = [u1::new(0); 2];

        // Two bits before the end: the trailing `11` of the second byte.
        cursor.seek(SeekFrom::End(-2)).expect("valid seek");
        assert_eq!(cursor.read_bits(&mut pair).unwrap(), 2);
        assert_eq!(pair, [u1::new(1); 2]);
        assert_eq!(cursor.read_bits(&mut pair).unwrap(), 0);

        // From the end (bit 16) back four bits to bit 12: `00`.
        cursor.seek(SeekFrom::Current(-4)).expect("valid seek");
        assert_eq!(cursor.read_bits(&mut pair).unwrap(), 2);
        assert_eq!(pair, [u1::new(0); 2]);

        // Absolute bit 4 is the second `11` of the first byte.
        cursor.seek(SeekFrom::Start(4)).expect("valid seek");
        assert_eq!(cursor.read_bits(&mut pair).unwrap(), 2);
        assert_eq!(pair, [u1::new(1); 2]);
    }

    #[test]
    fn test_read_bytes() {
        let mut cursor = BitCursor::new(BitVec::<u8, Msb0>::from_vec(vec![1, 2, 3, 4]));
        let mut buf = [0u8; 2];

        for expected in [[1u8, 2], [3, 4]] {
            std::io::Read::read(&mut cursor, &mut buf).expect("valid read");
            assert_eq!(buf, expected);
        }
    }

    #[test]
    fn test_sub_cursor_vec() {
        let mut cursor = BitCursor::new(BitVec::<u8, Msb0>::from_vec(vec![1, 2, 3, 4]));

        // Consume the first byte; the assertions below expect the sub-cursor range to be
        // relative to the parent's current position.
        let _ = cursor.read_u8().unwrap();
        let mut sub_cursor = cursor.sub_cursor(0..24);

        assert_eq!(sub_cursor.remaining_slice().len(), 24);
        assert_eq!(sub_cursor.read_u8().unwrap(), 2);
    }

    #[test]
    fn test_remaining_slice_u8() {
        let bytes: Vec<u8> = vec![0b0000_1111, 0b1010_1010];
        let mut cursor = BitCursor::new(bytes.as_slice());
        cursor.read_u4().unwrap();

        assert_eq!(
            cursor.remaining_slice(),
            bits![u8, Msb0; 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0]
        );
    }
}
404}