1use std::{
2 fmt::LowerHex,
3 io::{Read, Seek, SeekFrom, Write},
4};
5
6use crate::{
7 bit_read::BitRead,
8 bit_seek::BitSeek,
9 bit_write::BitWrite,
10 borrow_bits::{BorrowBits, BorrowBitsMut},
11 prelude::*,
12};
13
/// A cursor over a bit-addressable buffer.
///
/// The position is measured in bits (it is used directly as a bit index when the buffer is
/// split), not bytes.
#[derive(Debug, Default, PartialEq, Eq)]
pub struct BitCursor<T> {
    /// The wrapped buffer.
    inner: T,
    /// Current position, in bits, from the start of `inner`.
    pos: u64,
}
19
20impl<T> BitCursor<T> {
21 pub fn new(inner: T) -> BitCursor<T> {
25 BitCursor { inner, pos: 0 }
26 }
27
28 pub fn get_mut(&mut self) -> &mut T {
30 &mut self.inner
31 }
32
33 pub fn get_ref(&self) -> &T {
35 &self.inner
36 }
37
38 pub fn into_inner(self) -> T {
40 self.inner
41 }
42
43 pub fn position(&self) -> u64 {
45 self.pos
46 }
47
48 pub fn set_position(&mut self, pos: u64) {
50 self.pos = pos;
51 }
52}
53
54impl<T> BitCursor<T>
55where
56 T: BorrowBits,
57{
58 pub fn split(&self) -> (&BitSlice, &BitSlice) {
59 let bits = self.inner.borrow_bits();
60 bits.split_at(self.pos as usize)
61 }
62}
63
64impl<T> BitCursor<T>
65where
66 T: BorrowBitsMut,
67{
68 pub fn split_mut(&mut self) -> (&mut BitSlice<BitSafeU8>, &mut BitSlice<BitSafeU8>) {
69 let bits = self.inner.borrow_bits_mut();
70 let (left, right) = bits.split_at_mut(self.pos as usize);
71 (left, right)
72 }
73}
74
75impl<T> Clone for BitCursor<T>
76where
77 T: Clone,
78{
79 fn clone(&self) -> Self {
80 BitCursor {
81 inner: self.inner.clone(),
82 pos: self.pos,
83 }
84 }
85}
86
87impl<T> BitSeek for BitCursor<T>
88where
89 T: BorrowBits,
90{
91 fn bit_seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
92 let (base_pos, offset) = match pos {
93 SeekFrom::Start(n) => {
94 self.pos = n;
95 return Ok(n);
96 }
97 SeekFrom::End(n) => (self.inner.borrow_bits().len() as u64, n),
98 SeekFrom::Current(n) => (self.pos, n),
99 };
100 match base_pos.checked_add_signed(offset) {
101 Some(n) => {
102 self.pos = n;
103 Ok(self.pos)
104 }
105 None => Err(std::io::Error::new(
106 std::io::ErrorKind::InvalidInput,
107 "invalid seek to a negative or overlfowing position",
108 )),
109 }
110 }
111}
112
113impl<T> Seek for BitCursor<T>
114where
115 T: BorrowBits,
116{
117 fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
118 match pos {
119 SeekFrom::Start(n) => self.bit_seek(SeekFrom::Start(n * 8)),
120 SeekFrom::End(n) => self.bit_seek(SeekFrom::End(n * 8)),
121 SeekFrom::Current(n) => self.bit_seek(SeekFrom::Current(n * 8)),
122 }
123 }
124}
125
126impl<T> Read for BitCursor<T>
127where
128 T: BorrowBits,
129{
130 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
131 let bits = self.inner.borrow_bits();
132 let remaining = &bits[self.pos as usize..];
133 let mut bytes_read = 0;
134
135 for (i, chunk) in remaining.chunks(8).take(buf.len()).enumerate() {
136 let mut byte = 0u8;
137 for (j, bit) in chunk.iter().enumerate() {
138 if *bit {
139 byte |= 1 << (7 - j);
140 }
141 }
142 buf[i] = byte;
143 bytes_read += 1;
144 }
145
146 self.pos += (bytes_read * 8) as u64;
147 Ok(bytes_read)
148 }
149}
150
151impl<T> BitRead for BitCursor<T>
152where
153 T: BorrowBits,
154{
155 fn read_bits(&mut self, dest: &mut BitSlice) -> std::io::Result<usize> {
156 let n = BitRead::read_bits(&mut BitCursor::split(self).1, dest)?;
157 self.pos += n as u64;
158 Ok(n)
159 }
160}
161
162impl<T> Write for BitCursor<T>
163where
164 T: BorrowBitsMut,
165{
166 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
167 let n = Write::write(&mut BitCursor::split_mut(self).1, buf)?;
168 self.pos += (n * 8) as u64;
169 Ok(n)
170 }
171
172 fn flush(&mut self) -> std::io::Result<()> {
173 Ok(())
174 }
175}
176
177impl<T> BitWrite for BitCursor<T>
178where
179 T: BorrowBitsMut,
180 BitCursor<T>: std::io::Write,
181{
182 fn write_bits<O: BitStore>(&mut self, source: &BitSlice<O>) -> std::io::Result<usize> {
183 let n = BitWrite::write_bits(&mut BitCursor::split_mut(self).1, source)?;
184 self.pos += n as u64;
185 Ok(n)
186 }
187}
188
189impl<T> LowerHex for BitCursor<T>
190where
191 T: LowerHex,
192{
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 write!(f, "buf: {:x}, pos: {}", self.inner, self.pos)
195 }
196}
197
/// Unit tests for `BitCursor` over several backing buffer types
/// (`Vec<u8>`, `BitVec`, `&BitSlice` / `&mut BitSlice`, `&[u8]` / `&mut [u8]`).
#[cfg(test)]
mod test {
    use std::fmt::Debug;
    use std::io::{Seek, SeekFrom};

    use crate::prelude::*;
    use bitvec::bits;
    use bitvec::bitvec;
    use bitvec::view::BitView;
    use nsw_types::*;

    use crate::bit_read::BitRead;
    use crate::bit_read_exts::BitReadExts;
    use crate::bit_seek::BitSeek;
    use crate::bit_write_exts::BitWriteExts;
    use crate::borrow_bits::{BorrowBits, BorrowBitsMut};
    use crate::byte_order::NetworkOrder;

    use super::BitCursor;

    /// Reads `expected.len() * 8` bits from a fresh cursor over `buf`. Asserts that exactly
    /// that many bits were read and that they equal `expected` viewed MSB-first.
    fn test_read_bits_helper<T: BorrowBits>(buf: T, expected: &[u8]) {
        let expected_bits = expected.view_bits::<Msb0>();
        let mut cursor = BitCursor::new(buf);
        let mut read_buf = bitvec![u8, Msb0; 0; expected_bits.len()];
        assert_eq!(
            cursor.read_bits(&mut read_buf).unwrap(),
            expected_bits.len()
        );
        assert_eq!(read_buf, expected_bits);
    }

    /// `read_bits` behaves the same over every supported backing buffer type.
    #[test]
    fn test_read_bits() {
        let data = [0b11110000, 0b00001111];

        let vec = Vec::from(data);
        test_read_bits_helper(vec, &data);

        let bitvec = BitVec::from_slice(&data);
        test_read_bits_helper(bitvec, &data);

        let bitslice: &BitSlice = data.view_bits();
        test_read_bits_helper(bitslice, &data);

        let u8_slice = &data[..];
        test_read_bits_helper(u8_slice, &data);
    }

    /// Consecutive byte-level `Read::read` calls continue where the previous call stopped.
    #[test]
    fn test_read_bytes() {
        let data = BitVec::from_vec(vec![1, 2, 3, 4]);
        let mut cursor = BitCursor::new(data);

        let mut buf = [0u8; 2];
        std::io::Read::read(&mut cursor, &mut buf).expect("valid read");
        assert_eq!(buf, [1, 2]);
        std::io::Read::read(&mut cursor, &mut buf).expect("valid read");
        assert_eq!(buf, [3, 4]);
    }

    /// `bit_seek` offsets are measured in bits.
    #[test]
    fn test_bit_seek() {
        let data = BitVec::from_vec(vec![0b11001100, 0b00110011]);
        let mut cursor = BitCursor::new(data);

        let mut read_buf = bitvec![u8, Msb0; 0; 4];

        // Two bits before the end: only 2 of the 4 requested bits are available.
        cursor.bit_seek(SeekFrom::End(-2)).expect("valid seek");
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 2);
        assert_eq!(read_buf, bits![u8, Msb0; 1, 1, 0, 0]);
        // Now at the end, so nothing more can be read.
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 0);

        // From the end (bit 16), back up 6 bits to bit 10.
        cursor.bit_seek(SeekFrom::Current(-6)).expect("valid seek");
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 4);
        assert_eq!(read_buf, bits![u8, Msb0; 1, 1, 0, 0]);

        cursor.bit_seek(SeekFrom::Start(4)).expect("valid seek");
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 4);
        assert_eq!(read_buf, bits![u8, Msb0; 1, 1, 0, 0]);
    }

    /// `seek` offsets are measured in bytes, so each offset moves the cursor by 8 bits.
    #[test]
    fn test_seek() {
        let data = BitVec::from_vec(vec![0b11001100, 0b00110011]);
        let mut cursor = BitCursor::new(data);

        let mut read_buf = bitvec![u8, Msb0; 0; 2];
        // One byte before the end lands on bit 8.
        cursor.seek(SeekFrom::End(-1)).unwrap();
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 2);
        assert_eq!(read_buf, bits![u8, Msb0; 0, 0]);
        // From bit 10, one byte back lands on bit 2.
        cursor.seek(SeekFrom::Current(-1)).unwrap();
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 2);
        assert_eq!(read_buf, bits![u8, Msb0; 0, 0]);
    }

    /// Writes 16 bits of non-byte-aligned fields (4 + 2 + 2 + 3 + 5) that together spell
    /// `0b11001100, 0b11001100`, then returns the buffer.
    fn test_write_bits_helper<T: BorrowBitsMut>(buf: T) -> T {
        let mut cursor = BitCursor::new(buf);
        cursor.write_u4(u4::new(0b1100)).unwrap();
        cursor.write_u2(u2::new(0b11)).unwrap();
        cursor.write_u2(u2::new(0b00)).unwrap();
        cursor.write_u3(u3::new(0b110)).unwrap();
        cursor.write_u5(u5::new(0b01100)).unwrap();
        cursor.into_inner()
    }

    #[test]
    fn test_write_bits_bitvec() {
        let buf = BitVec::from_vec(vec![0; 2]);

        assert_eq!(
            test_write_bits_helper(buf),
            BitVec::from_vec(vec![0b11001100, 0b11001100])
        );
    }

    #[test]
    fn test_write_bits_vec() {
        let buf: Vec<u8> = vec![0, 0];

        assert_eq!(test_write_bits_helper(buf), [0b11001100, 0b11001100]);
    }

    #[test]
    fn test_write_bits_bit_slice() {
        let mut data = [0u8; 2];
        let buf: &mut BitSlice = data.view_bits_mut::<Msb0>();

        assert_eq!(
            test_write_bits_helper(buf),
            BitVec::from_vec(vec![0b11001100, 0b11001100]).as_bitslice()
        );
    }

    #[test]
    fn test_write_bits_u8_slice() {
        let mut buf = [0u8; 2];

        assert_eq!(
            test_write_bits_helper(&mut buf[..]),
            [0b11001100, 0b11001100]
        );
    }

    /// Moves a fresh cursor over `buf` to bit 4. Asserts that `split` yields the first 4
    /// bits and the rest of `expected`.
    fn test_split_helper<T: BorrowBits>(buf: T, expected: &[u8]) {
        let expected_bits = expected.view_bits::<Msb0>();
        let mut cursor = BitCursor::new(buf);
        cursor.bit_seek(SeekFrom::Current(4)).unwrap();
        let (before, after) = cursor.split();

        assert_eq!(before, expected_bits[..4]);
        assert_eq!(after, expected_bits[4..]);
    }

    #[test]
    fn test_split() {
        let data = [0b11110011, 0b10101010];

        let vec = Vec::from(data);
        test_split_helper(vec, &data);

        let bitvec = BitVec::from_slice(&data);
        test_split_helper(bitvec, &data);

        let bitslice: &BitSlice = data.view_bits();
        test_split_helper(bitslice, &data);

        let u8_slice = &data[..];
        test_split_helper(u8_slice, &data);
    }

    /// The halves returned by `split` are themselves readable bit sources.
    #[test]
    fn test_cursors_from_splits() {
        let data = [0b11110011, 0b10101010];

        // `seek` is byte-based, so `Start(1)` splits at bit 8, between the two bytes.
        let vec = Vec::from(data);
        let mut vec_cursor = BitCursor::new(vec);
        vec_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = vec_cursor.split();
        test_read_bits_helper(left, &data[..1]);
        test_read_bits_helper(right, &data[1..]);

        let bitvec = BitVec::from_slice(&data);
        let mut bitvec_cursor = BitCursor::new(bitvec);
        bitvec_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = bitvec_cursor.split();
        test_read_bits_helper(left, &data[..1]);
        test_read_bits_helper(right, &data[1..]);

        let bitslice: &BitSlice = data.view_bits();
        let mut bitslice_cursor = BitCursor::new(bitslice);
        bitslice_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = bitslice_cursor.split();
        test_read_bits_helper(left, &data[..1]);
        test_read_bits_helper(right, &data[1..]);

        let u8_slice = &data[..];
        let mut u8_cursor = BitCursor::new(u8_slice);
        u8_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = u8_cursor.split();
        test_read_bits_helper(left, &data[..1]);
        test_read_bits_helper(right, &data[1..]);
    }

    /// Splits a 4-byte buffer at byte 2, writes one u16 into each mutable half, and compares
    /// the resulting buffer with `create_expected` applied to the expected bytes.
    fn test_split_mut_helper<T, U, F>(buf: T, create_expected: F)
    where
        T: BorrowBitsMut + PartialEq<U> + Debug,
        U: Debug,
        F: FnOnce(&[u8]) -> U,
    {
        let mut cursor = BitCursor::new(buf);
        cursor.seek(SeekFrom::Start(2)).unwrap();
        {
            // Scoped so the mutable halves are released before `into_inner`.
            let (mut before, mut after) = cursor.split_mut();

            before
                .write_u16::<NetworkOrder>(0b1111111100000000)
                .unwrap();
            after.write_u16::<NetworkOrder>(0b1100110000110011).unwrap();
        }

        let data = cursor.into_inner();
        let expected = create_expected(&[0b11111111, 0b00000000, 0b11001100, 0b00110011]);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_split_mut() {
        let data = [0u8; 4];

        let vec = Vec::from(data);
        test_split_mut_helper(vec, |v| v.to_vec());

        let bitvec = BitVec::from_vec(vec![0u8; 4]);
        test_split_mut_helper(bitvec, |v| BitVec::from_vec(v.to_vec()));

        let mut data = [0u8; 4];
        let bitslice: &mut BitSlice = data.view_bits_mut();
        test_split_mut_helper(bitslice, |v| BitVec::from_vec(v.to_vec()));

        let mut data = [0u8; 4];
        let u8_slice = &mut data[..];
        test_split_mut_helper(u8_slice, |v| v.to_vec());
    }

    /// A u16 written at any bit offset within the first byte reads back unchanged.
    #[test]
    fn test_alignment_reads_writes() {
        for offset in 0..8 {
            let buf = vec![0u8; 4];
            let mut cursor = BitCursor::new(buf);
            cursor.set_position(offset);
            let value = 0xDEADu16;
            cursor.write_u16::<BigEndian>(value).unwrap();
            cursor.set_position(offset);
            let read_value = cursor.read_u16::<BigEndian>().unwrap();
            assert_eq!(value, read_value, "offset {offset}");
        }
    }
}