1use std::{
2 fmt::LowerHex,
3 io::{Read, Seek, SeekFrom, Write},
4};
5
6use crate::prelude::*;
7
/// A cursor over a bit-addressable buffer.
///
/// Unlike `std::io::Cursor`, the position is counted in bits, not bytes.
#[derive(Debug, Default, PartialEq, Eq)]
pub struct BitCursor<T> {
    /// The wrapped buffer.
    inner: T,
    /// Current position, in bits from the start of `inner`.
    pos: u64,
}

impl<T> BitCursor<T> {
    /// Wraps `inner` in a new cursor positioned at bit 0.
    pub fn new(inner: T) -> BitCursor<T> {
        Self { pos: 0, inner }
    }

    /// Consumes the cursor and hands back the wrapped buffer.
    pub fn into_inner(self) -> T {
        let Self { inner, .. } = self;
        inner
    }

    /// Borrows the wrapped buffer.
    pub fn get_ref(&self) -> &T {
        &self.inner
    }

    /// Mutably borrows the wrapped buffer.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.inner
    }

    /// Returns the current position in bits.
    pub fn position(&self) -> u64 {
        self.pos
    }

    /// Moves the cursor to bit `pos`; no bounds check is performed.
    pub fn set_position(&mut self, pos: u64) {
        self.pos = pos
    }
}
47
48impl<T> BitCursor<T>
49where
50 T: BorrowBits,
51{
52 pub fn split(&self) -> (&BitSlice, &BitSlice) {
54 let bits = self.inner.borrow_bits();
55 bits.split_at(self.pos as usize)
56 }
57}
58
59impl<T> BitCursor<T>
60where
61 T: BorrowBitsMut,
62{
63 pub fn split_mut(
66 &mut self,
67 ) -> (
68 &mut BitSlice<bitvec::access::BitSafeU8>,
69 &mut BitSlice<bitvec::access::BitSafeU8>,
70 ) {
71 let bits = self.inner.borrow_bits_mut();
72 let (left, right) = bits.split_at_mut(self.pos as usize);
73 (left, right)
74 }
75}
76
77impl<T> Clone for BitCursor<T>
78where
79 T: Clone,
80{
81 fn clone(&self) -> Self {
82 BitCursor {
83 inner: self.inner.clone(),
84 pos: self.pos,
85 }
86 }
87}
88
89impl<T> BitSeek for BitCursor<T>
90where
91 T: BorrowBits,
92{
93 fn bit_seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
94 let (base_pos, offset) = match pos {
95 SeekFrom::Start(n) => {
96 self.pos = n;
97 return Ok(n);
98 }
99 SeekFrom::End(n) => (self.inner.borrow_bits().len() as u64, n),
100 SeekFrom::Current(n) => (self.pos, n),
101 };
102 match base_pos.checked_add_signed(offset) {
103 Some(n) => {
104 self.pos = n;
105 Ok(self.pos)
106 }
107 None => Err(std::io::Error::new(
108 std::io::ErrorKind::InvalidInput,
109 "invalid seek to a negative or overlfowing position",
110 )),
111 }
112 }
113}
114
115impl<T> Seek for BitCursor<T>
116where
117 T: BorrowBits,
118{
119 fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
120 match pos {
121 SeekFrom::Start(n) => self.bit_seek(SeekFrom::Start(n * 8)),
122 SeekFrom::End(n) => self.bit_seek(SeekFrom::End(n * 8)),
123 SeekFrom::Current(n) => self.bit_seek(SeekFrom::Current(n * 8)),
124 }
125 }
126}
127
128impl<T> Read for BitCursor<T>
129where
130 T: BorrowBits,
131{
132 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
133 let bits = self.inner.borrow_bits();
134 let remaining = &bits[self.pos as usize..];
135 let mut bytes_read = 0;
136
137 for (i, chunk) in remaining.chunks(8).take(buf.len()).enumerate() {
138 let mut byte = 0u8;
139 for (j, bit) in chunk.iter().enumerate() {
140 if *bit {
141 byte |= 1 << (7 - j);
142 }
143 }
144 buf[i] = byte;
145 bytes_read += 1;
146 }
147
148 self.pos += (bytes_read * 8) as u64;
149 Ok(bytes_read)
150 }
151}
152
153impl<T> BitRead for BitCursor<T>
154where
155 T: BorrowBits,
156{
157 fn read_bits<O: BitStore>(&mut self, dest: &mut BitSlice<O>) -> std::io::Result<usize> {
158 let n = BitRead::read_bits(&mut BitCursor::split(self).1, dest)?;
159 self.pos += n as u64;
160 Ok(n)
161 }
162}
163
164impl<T> Write for BitCursor<T>
165where
166 T: BorrowBitsMut,
167{
168 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
169 let n = Write::write(&mut BitCursor::split_mut(self).1, buf)?;
170 self.pos += (n * 8) as u64;
171 Ok(n)
172 }
173
174 fn flush(&mut self) -> std::io::Result<()> {
175 Ok(())
176 }
177}
178
179impl<T> BitWrite for BitCursor<T>
180where
181 T: BorrowBitsMut,
182 BitCursor<T>: std::io::Write,
183{
184 fn write_bits<O: BitStore>(&mut self, source: &BitSlice<O>) -> std::io::Result<usize> {
185 let n = BitWrite::write_bits(&mut BitCursor::split_mut(self).1, source)?;
186 self.pos += n as u64;
187 Ok(n)
188 }
189}
190
191impl<T> LowerHex for BitCursor<T>
192where
193 T: LowerHex,
194{
195 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196 write!(f, "buf: {:x}, pos: {}", self.inner, self.pos)
197 }
198}
199
#[cfg(test)]
mod test {
    use std::fmt::Debug;
    use std::io::{Seek, SeekFrom};

    use crate::prelude::*;
    use bitvec::{order::Msb0, view::BitView};

    /// Reads `expected`'s worth of bits from a fresh cursor over `buf`.
    ///
    /// Asserts that the read reports the full length and yields `expected`
    /// viewed MSB-first.
    fn test_read_bits_helper<T: BorrowBits>(buf: T, expected: &[u8]) {
        let expected_bits = expected.view_bits::<Msb0>();
        let mut cursor = BitCursor::new(buf);
        let mut read_buf = bitvec![0; expected_bits.len()];
        assert_eq!(
            cursor.read_bits(read_buf.as_mut_bitslice()).unwrap(),
            expected_bits.len()
        );
        assert_eq!(read_buf, expected_bits);
    }

    #[test]
    fn test_read_bits() {
        let data = [0b11110000, 0b00001111];

        let vec = Vec::from(data);
        test_read_bits_helper(vec, &data);

        let bitvec = BitVec::from_slice(&data);
        test_read_bits_helper(bitvec, &data);

        let bitslice: &BitSlice = data.view_bits();
        test_read_bits_helper(bitslice, &data);

        let u8_slice = &data[..];
        test_read_bits_helper(u8_slice, &data);
    }

    #[test]
    fn test_read_bytes() {
        let data = BitVec::from_vec(vec![1, 2, 3, 4]);
        let mut cursor = BitCursor::new(data);

        let mut buf = [0u8; 2];
        std::io::Read::read(&mut cursor, &mut buf).expect("valid read");
        assert_eq!(buf, [1, 2]);
        std::io::Read::read(&mut cursor, &mut buf).expect("valid read");
        assert_eq!(buf, [3, 4]);
    }

    #[test]
    fn test_bit_seek() {
        let data = BitVec::from_vec(vec![0b11001100, 0b00110011]);
        let mut cursor = BitCursor::new(data);

        let mut read_buf = bitvec![0; 4];

        // Two bits before the end: only 2 bits remain to be read.
        cursor.bit_seek(SeekFrom::End(-2)).expect("valid seek");
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 2);
        assert_eq!(read_buf, bits![1, 1, 0, 0]);
        // Now at the end: nothing left.
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 0);

        // From bit 16 back to bit 10.
        cursor.bit_seek(SeekFrom::Current(-6)).expect("valid seek");
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 4);
        assert_eq!(read_buf, bits![1, 1, 0, 0]);

        cursor.bit_seek(SeekFrom::Start(4)).expect("valid seek");
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 4);
        assert_eq!(read_buf, bits![1, 1, 0, 0]);
    }

    #[test]
    fn test_seek() {
        let data = BitVec::from_vec(vec![0b11001100, 0b00110011]);
        let mut cursor = BitCursor::new(data);

        let mut read_buf = bitvec![0; 2];
        // `seek` offsets are in bytes: End(-1) lands on bit 8.
        cursor.seek(SeekFrom::End(-1)).unwrap();
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 2);
        assert_eq!(read_buf, bits![0, 0]);
        // One byte back from bit 10 lands on bit 2.
        cursor.seek(SeekFrom::Current(-1)).unwrap();
        assert_eq!(cursor.read_bits(&mut read_buf).unwrap(), 2);
        assert_eq!(read_buf, bits![0, 0]);
    }

    /// Writes a 16-bit pattern through a fresh cursor over `buf`.
    ///
    /// Asserts that all 16 bits were accepted and the buffer now holds exactly
    /// that pattern.
    fn test_write_bits_helper<T>(buf: T)
    where
        T: BorrowBitsMut + Debug,
        BitCursor<T>: std::io::Write,
    {
        let mut cursor = BitCursor::new(buf);
        let data = bits![1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0];
        assert_eq!(cursor.write_bits(data).unwrap(), 16);
        assert_eq!(cursor.into_inner().borrow_bits(), data);
    }

    #[test]
    fn test_write_bits_bitvec() {
        let buf = BitVec::from_vec(vec![0; 2]);
        test_write_bits_helper(buf);
    }

    #[test]
    fn test_write_bits_vec() {
        let buf: Vec<u8> = vec![0, 0];
        test_write_bits_helper(buf);
    }

    #[test]
    fn test_write_bits_bit_slice() {
        let mut buf = bitvec![0; 16];
        test_write_bits_helper(buf.as_mut_bitslice());
    }

    #[test]
    fn test_write_bits_u8_slice() {
        let mut buf = [0u8; 2];
        test_write_bits_helper(&mut buf[..]);
    }

    /// Moves a cursor over `buf` to bit 4.
    ///
    /// Asserts that `split` divides the data there, comparing against `expected`
    /// viewed MSB-first.
    fn test_split_helper<T: BorrowBits>(buf: T, expected: &[u8]) {
        let expected_bits = expected.view_bits::<Msb0>();
        let mut cursor = BitCursor::new(buf);
        cursor.bit_seek(SeekFrom::Current(4)).unwrap();
        let (before, after) = cursor.split();

        assert_eq!(before, expected_bits[..4]);
        assert_eq!(after, expected_bits[4..]);
    }

    #[test]
    fn test_split() {
        let data = [0b11110011, 0b10101010];

        let vec = Vec::from(data);
        test_split_helper(vec, &data);

        let bitvec = BitVec::from_slice(&data);
        test_split_helper(bitvec, &data);

        let bitslice: &BitSlice = data.view_bits();
        test_split_helper(bitslice, &data);

        let u8_slice = &data[..];
        test_split_helper(u8_slice, &data);
    }

    #[test]
    fn test_cursors_from_splits() {
        let data = [0b11110011, 0b10101010];

        let vec = Vec::from(data);
        let mut vec_cursor = BitCursor::new(vec);
        vec_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = vec_cursor.split();
        test_read_bits_helper(left, &data[..1]);
        test_read_bits_helper(right, &data[1..]);

        let bitvec = BitVec::from_slice(&data);
        let mut bitvec_cursor = BitCursor::new(bitvec);
        bitvec_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = bitvec_cursor.split();
        test_read_bits_helper(left, &data[..1]);
        test_read_bits_helper(right, &data[1..]);

        let bitslice: &BitSlice = data.view_bits();
        let mut bitslice_cursor = BitCursor::new(bitslice);
        bitslice_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = bitslice_cursor.split();
        test_read_bits_helper(left, &data[..1]);
        test_read_bits_helper(right, &data[1..]);

        let u8_slice = &data[..];
        let mut u8_cursor = BitCursor::new(u8_slice);
        u8_cursor.seek(SeekFrom::Start(1)).unwrap();
        let (left, right) = u8_cursor.split();
        test_read_bits_helper(left, &data[..1]);
        test_read_bits_helper(right, &data[1..]);
    }

    /// Seeks a cursor over a 4-byte `buf` to byte 2 (bit 16) and writes a 16-bit
    /// pattern into each half returned by `split_mut`.
    ///
    /// Asserts that both halves land in the buffer independently.
    fn test_split_mut_helper<T>(buf: T)
    where
        T: BorrowBitsMut + Debug,
        BitCursor<T>: std::io::Write,
    {
        let mut cursor = BitCursor::new(buf);
        cursor.seek(SeekFrom::Start(2)).unwrap();

        {
            let (mut left, mut right) = cursor.split_mut();

            left.write_bits(bits![1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0])
                .unwrap();
            right
                .write_bits(bits![0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1])
                .unwrap();
        }

        let data = cursor.into_inner();
        assert_eq!(
            data.borrow_bits(),
            bits![
                1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0,
                1, 1, 1, 1
            ]
        );
    }

    #[test]
    fn test_split_mut() {
        let data = [0u8; 4];

        let vec = Vec::from(data);
        test_split_mut_helper(vec);

        let bitvec = BitVec::from_vec(vec![0u8; 4]);
        test_split_mut_helper(bitvec);

        let mut data = [0u8; 4];
        let bitslice: &mut BitSlice = data.view_bits_mut();
        test_split_mut_helper(bitslice);

        let mut data = [0u8; 4];
        let u8_slice = &mut data[..];
        test_split_mut_helper(u8_slice);
    }

    /// A 16-bit write followed by a read at the same bit offset must round-trip
    /// for every sub-byte alignment.
    #[test]
    fn test_alignment_reads_writes() {
        for offset in 0..8 {
            let buf = vec![0u8; 4];
            let mut cursor = BitCursor::new(buf);
            cursor.set_position(offset);
            let value = BitVec::from_slice(&[0xDE, 0xAD]);

            cursor.write_bits(value.as_bitslice()).unwrap();

            cursor.set_position(offset);
            let mut read_buf = BitVec::with_capacity(16);
            read_buf.resize(16, false);
            cursor.read_bits(read_buf.as_mut_bitslice()).unwrap();
            assert_eq!(value, read_buf, "offset {offset}");
        }
    }
}