1use std::{
2 fmt::LowerHex,
3 io::{Read, Seek, SeekFrom, Write},
4};
5
6use crate::{
7 bit_read::BitRead,
8 bit_seek::BitSeek,
9 bit_write::BitWrite,
10 borrow_bits::{BorrowBits, BorrowBitsMut},
11 prelude::*,
12};
13
/// A cursor over a bit-addressable buffer.
///
/// The cursor owns (or borrows, depending on `T`) the wrapped buffer and keeps
/// a position counted in **bits** from the start of the buffer's bit view.
/// Byte-oriented I/O (`Read`, `Write`, `Seek`) on this type converts between
/// bytes and bits at a ratio of 8 bits per byte.
#[derive(Debug, Default, PartialEq, Eq)]
pub struct BitCursor<T> {
    /// The wrapped buffer.
    inner: T,
    /// Current offset into `inner`, in bits.
    pos: u64,
}
19
20impl<T> BitCursor<T> {
21 pub fn new(inner: T) -> BitCursor<T> {
25 BitCursor { inner, pos: 0 }
26 }
27
28 pub fn get_mut(&mut self) -> &mut T {
30 &mut self.inner
31 }
32
33 pub fn get_ref(&self) -> &T {
35 &self.inner
36 }
37
38 pub fn into_inner(self) -> T {
40 self.inner
41 }
42
43 pub fn position(&self) -> u64 {
45 self.pos
46 }
47
48 pub fn set_position(&mut self, pos: u64) {
50 self.pos = pos;
51 }
52}
53
54impl<T> BitCursor<T>
55where
56 T: BorrowBits,
57{
58 pub fn split(&self) -> (&BitSlice, &BitSlice) {
59 let bits = self.inner.borrow_bits();
60 bits.split_at(self.pos as usize)
61 }
62}
63
64impl<T> BitCursor<T>
65where
66 T: BorrowBitsMut,
67{
68 pub fn split_mut(&mut self) -> (&mut BitSlice<impl BitStore>, &mut BitSlice<impl BitStore>) {
69 let bits = self.inner.borrow_bits_mut();
70 let (left, right) = bits.split_at_mut(self.pos as usize);
71 (left, right)
72 }
73}
74
75impl<T> Clone for BitCursor<T>
76where
77 T: Clone,
78{
79 fn clone(&self) -> Self {
80 BitCursor {
81 inner: self.inner.clone(),
82 pos: self.pos,
83 }
84 }
85}
86
87impl<T> BitSeek for BitCursor<T>
88where
89 T: BorrowBits,
90{
91 fn bit_seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
92 let (base_pos, offset) = match pos {
93 SeekFrom::Start(n) => {
94 self.pos = n;
95 return Ok(n);
96 }
97 SeekFrom::End(n) => (self.inner.borrow_bits().len() as u64, n),
98 SeekFrom::Current(n) => (self.pos, n),
99 };
100 match base_pos.checked_add_signed(offset) {
101 Some(n) => {
102 self.pos = n;
103 Ok(self.pos)
104 }
105 None => Err(std::io::Error::new(
106 std::io::ErrorKind::InvalidInput,
107 "invalid seek to a negative or overlfowing position",
108 )),
109 }
110 }
111}
112
113impl<T> Seek for BitCursor<T>
114where
115 T: BorrowBits,
116{
117 fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
118 match pos {
119 SeekFrom::Start(n) => self.bit_seek(SeekFrom::Start(n * 8)),
120 SeekFrom::End(n) => self.bit_seek(SeekFrom::End(n * 8)),
121 SeekFrom::Current(n) => self.bit_seek(SeekFrom::Current(n * 8)),
122 }
123 }
124}
125
126impl<T> Read for BitCursor<T>
127where
128 T: BorrowBits,
129{
130 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
131 let bits = self.inner.borrow_bits();
132 let remaining = &bits[self.pos as usize..];
133 let mut bytes_read = 0;
134
135 for (i, chunk) in remaining.chunks(8).take(buf.len()).enumerate() {
136 let mut byte = 0u8;
137 for (j, bit) in chunk.iter().enumerate() {
138 if *bit {
139 byte |= 1 << (7 - j);
140 }
141 }
142 buf[i] = byte;
143 bytes_read += 1;
144 }
145
146 self.pos += (bytes_read * 8) as u64;
147 Ok(bytes_read)
148 }
149}
150
151impl<T> BitRead for BitCursor<T>
152where
153 T: BorrowBits,
154{
155 fn read_bits(&mut self, dest: &mut BitSlice) -> std::io::Result<usize> {
156 let n = BitRead::read_bits(&mut BitCursor::split(self).1, dest)?;
157 self.pos += n as u64;
158 Ok(n)
159 }
160}
161
162impl<T> Write for BitCursor<T>
163where
164 T: BorrowBitsMut,
165{
166 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
167 let n = Write::write(&mut BitCursor::split_mut(self).1, buf)?;
168 self.pos += (n * 8) as u64;
169 Ok(n)
170 }
171
172 fn flush(&mut self) -> std::io::Result<()> {
173 Ok(())
174 }
175}
176
177impl<T> BitWrite for BitCursor<T>
178where
179 T: BorrowBitsMut,
180 BitCursor<T>: std::io::Write,
181{
182 fn write_bits<O: BitStore>(&mut self, source: &BitSlice<O>) -> std::io::Result<usize> {
183 let n = BitWrite::write_bits(&mut BitCursor::split_mut(self).1, source)?;
184 self.pos += n as u64;
185 Ok(n)
186 }
187}
188
189impl<T> LowerHex for BitCursor<T>
190where
191 T: LowerHex,
192{
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 write!(f, "buf: {:x}, pos: {}", self.inner, self.pos)
195 }
196}
197
#[cfg(test)]
mod test {
    use std::fmt::Debug;
    use std::io::{Seek, SeekFrom};

    use crate::prelude::*;
    use bitvec::{order::Msb0, view::BitView};
    use nsw_types::*;

    /// Reads every bit of `expected` (MSB-first) from a fresh cursor over `buf`.
    /// Asserts both the reported bit count and the bit contents.
    fn test_read_bits_helper<T: BorrowBits>(buf: T, expected: &[u8]) {
        let expected_bits = expected.view_bits::<Msb0>();
        let mut cursor = BitCursor::new(buf);
        let mut read_buf = bitvec![0; expected_bits.len()];
        assert_eq!(
            cursor.read_bits(read_buf.as_mut_bitslice()).unwrap(),
            expected_bits.len()
        );
        assert_eq!(read_buf, expected_bits);
    }

    #[test]
    fn test_read_bits() {
        let data = [0b11110000, 0b00001111];

        let vec = Vec::from(data);
        test_read_bits_helper(vec, &data);

        let bitvec = BitVec::from_slice(&data);
        test_read_bits_helper(bitvec, &data);

        let bitslice: &BitSlice = data.view_bits();
        test_read_bits_helper(bitslice, &data);

        let u8_slice = &data[..];
        test_read_bits_helper(u8_slice, &data);
    }

    #[test]
    fn test_read_bytes() {
        let data = BitVec::from_vec(vec![1, 2, 3, 4]);
        let mut cursor = BitCursor::new(data);

        let mut buf = [0u8; 2];
        std::io::Read::read(&mut cursor, &mut buf).expect("valid read");
        assert_eq!(buf, [1, 2]);
        std::io::Read::read(&mut cursor, &mut buf).expect("valid read");
        assert_eq!(buf, [3, 4]);
    }

    #[test]
    fn test_bit_seek() {
        // 16 bits: 1100_1100 0011_0011
        let data = BitVec::from_vec(vec![0b11001100, 0b00110011]);
        let mut cursor = BitCursor::new(data);

        let mut read_buf = bitvec![0; 4];

        // Two bits before the end: only 2 of the 4 requested bits are available.
        cursor.bit_seek(SeekFrom::End(-2)).expect("valid seek");
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 2);
        // The two bits read land at the front of `read_buf`.
        assert_eq!(read_buf, bits![1, 1, 0, 0]);
        // At the end, nothing more can be read.
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 0);

        // Back up from bit 16 to bit 10.
        cursor.bit_seek(SeekFrom::Current(-6)).expect("valid seek");
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 4);
        assert_eq!(read_buf, bits![1, 1, 0, 0]);

        cursor.bit_seek(SeekFrom::Start(4)).expect("valid seek");
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 4);
        assert_eq!(read_buf, bits![1, 1, 0, 0]);
    }

    #[test]
    fn test_seek() {
        let data = BitVec::from_vec(vec![0b11001100, 0b00110011]);
        let mut cursor = BitCursor::new(data);

        let mut read_buf = bitvec![0; 2];
        // Byte seeks move in steps of 8 bits: End(-1) is bit 8.
        cursor.seek(SeekFrom::End(-1)).unwrap();
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 2);
        assert_eq!(read_buf, bits![0, 0]);
        // Now at bit 10; Current(-1) moves back 8 bits to bit 2.
        cursor.seek(SeekFrom::Current(-1)).unwrap();
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 2);
        assert_eq!(read_buf, bits![0, 0]);
    }

    /// Writes 16 bits (`1100` `11` `00` `110` `01100`) through a cursor over
    /// `buf` and returns the buffer. Expected bytes: `0b11001100, 0b11001100`.
    fn test_write_bits_helper<T: BorrowBitsMut>(buf: T) -> T {
        let mut cursor = BitCursor::new(buf);
        cursor.write_u4(u4::new(0b1100)).unwrap();
        cursor.write_u2(u2::new(0b11)).unwrap();
        cursor.write_u2(u2::new(0b00)).unwrap();
        cursor.write_u3(u3::new(0b110)).unwrap();
        cursor.write_u5(u5::new(0b01100)).unwrap();
        cursor.into_inner()
    }

    #[test]
    fn test_write_bits_bitvec() {
        let buf = BitVec::from_vec(vec![0; 2]);

        assert_eq!(
            test_write_bits_helper(buf),
            BitVec::from_vec(vec![0b11001100, 0b11001100])
        );
    }

    #[test]
    fn test_write_bits_vec() {
        let buf: Vec<u8> = vec![0, 0];

        assert_eq!(test_write_bits_helper(buf), [0b11001100, 0b11001100]);
    }

    #[test]
    fn test_write_bits_bit_slice() {
        let mut data = [0u8; 2];
        let buf: &mut BitSlice = data.view_bits_mut::<Msb0>();

        assert_eq!(
            test_write_bits_helper(buf),
            BitVec::from_vec(vec![0b11001100, 0b11001100]).as_bitslice()
        );
    }

    #[test]
    fn test_write_bits_u8_slice() {
        let mut buf = [0u8; 2];

        assert_eq!(
            test_write_bits_helper(&mut buf[..]),
            [0b11001100, 0b11001100]
        );
    }

    /// Advances a cursor over `buf` by 4 bits. Checks that `split` yields the
    /// first 4 bits of `expected` before the cursor and the rest after it.
    fn test_split_helper<T: BorrowBits>(buf: T, expected: &[u8]) {
        let expected_bits = expected.view_bits::<Msb0>();
        let mut cursor = BitCursor::new(buf);
        cursor.bit_seek(SeekFrom::Current(4)).unwrap();
        let (before, after) = cursor.split();

        assert_eq!(before, expected_bits[..4]);
        assert_eq!(after, expected_bits[4..]);
    }

    #[test]
    fn test_split() {
        let data = [0b11110011, 0b10101010];

        let vec = Vec::from(data);
        test_split_helper(vec, &data);

        let bitvec = BitVec::from_slice(&data);
        test_split_helper(bitvec, &data);

        let bitslice: &BitSlice = data.view_bits();
        test_split_helper(bitslice, &data);

        let u8_slice = &data[..];
        test_split_helper(u8_slice, &data);
    }

    // Each half of a split is itself readable as a fresh cursor source.
    #[test]
    fn test_cursors_from_splits() {
        let data = [0b11110011, 0b10101010];

        let vec = Vec::from(data);
        let mut vec_cursor = BitCursor::new(vec);
        vec_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = vec_cursor.split();
        test_read_bits_helper(left, &data[..1]);
        test_read_bits_helper(right, &data[1..]);

        let bitvec = BitVec::from_slice(&data);
        let mut bitvec_cursor = BitCursor::new(bitvec);
        bitvec_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = bitvec_cursor.split();
        test_read_bits_helper(left, &data[..1]);
        test_read_bits_helper(right, &data[1..]);

        let bitslice: &BitSlice = data.view_bits();
        let mut bitslice_cursor = BitCursor::new(bitslice);
        bitslice_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = bitslice_cursor.split();
        test_read_bits_helper(left, &data[..1]);
        test_read_bits_helper(right, &data[1..]);

        let u8_slice = &data[..];
        let mut u8_cursor = BitCursor::new(u8_slice);
        u8_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = u8_cursor.split();
        test_read_bits_helper(left, &data[..1]);
        test_read_bits_helper(right, &data[1..]);
    }

    /// Splits a 4-byte buffer at byte 2 and writes a `u16` into each half.
    /// Compares the resulting buffer against `create_expected` applied to the
    /// expected bytes.
    fn test_split_mut_helper<T, U, F>(buf: T, create_expected: F)
    where
        T: BorrowBitsMut + PartialEq<U> + Debug,
        U: Debug,
        F: FnOnce(&[u8]) -> U,
    {
        let mut cursor = BitCursor::new(buf);
        cursor.seek(SeekFrom::Start(2)).unwrap();
        {
            let (mut before, mut after) = cursor.split_mut();

            before
                .write_u16::<NetworkOrder>(0b1111111100000000)
                .unwrap();
            after.write_u16::<NetworkOrder>(0b1100110000110011).unwrap();
        }

        let data = cursor.into_inner();
        let expected = create_expected(&[0b11111111, 0b00000000, 0b11001100, 0b00110011]);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_split_mut() {
        let data = [0u8; 4];

        let vec = Vec::from(data);
        test_split_mut_helper(vec, |v| v.to_vec());

        let bitvec = BitVec::from_vec(vec![0u8; 4]);
        test_split_mut_helper(bitvec, |v| BitVec::from_vec(v.to_vec()));

        let mut data = [0u8; 4];
        let bitslice: &mut BitSlice = data.view_bits_mut();
        test_split_mut_helper(bitslice, |v| BitVec::from_vec(v.to_vec()));

        let mut data = [0u8; 4];
        let u8_slice = &mut data[..];
        test_split_mut_helper(u8_slice, |v| v.to_vec());
    }

    // Round-trips a u16 at every bit offset within the first byte, so that
    // byte-unaligned reads and writes are exercised.
    #[test]
    fn test_alignment_reads_writes() {
        for offset in 0..8 {
            let buf = vec![0u8; 4];
            let mut cursor = BitCursor::new(buf);
            cursor.set_position(offset);
            let value = 0xDEADu16;
            cursor.write_u16::<BigEndian>(value).unwrap();
            cursor.set_position(offset);
            let read_value = cursor.read_u16::<BigEndian>().unwrap();
            assert_eq!(value, read_value, "offset {offset}");
        }
    }
}