1use std::{
2 io::{self, Write},
3 fmt::{self, Debug, Formatter},
4};
5
6use crate::*;
7use io_utils::{Writer, CursorVecU8};
8
/// Byte masks indexed by bit count: `MASK8[n]` has exactly the lowest `n` bits set (`n` in `0..=8`).
const MASK8: [u8; 9] = [
    0b0000_0000,
    0b0000_0001,
    0b0000_0011,
    0b0000_0111,
    0b0000_1111,
    0b0001_1111,
    0b0011_1111,
    0b0111_1111,
    0b1111_1111,
];
10
/// Word masks indexed by bit count: `MASK[n]` has exactly the lowest `n` bits set (`n` in `0..=32`).
const MASK: [u32; 33] = {
    // Entry 0 stays zero; every other entry is the all-ones word with the
    // unwanted high bits shifted out.
    let mut table = [0u32; 33];
    let mut width: usize = 1;
    while width <= 32 {
        table[width] = u32::MAX >> ((32 - width) as u32);
        width += 1;
    }
    table
};
22
/// Declares the `BITS` and `ALIGN` constants for the `Unit` type alias that is
/// visible where the macro is invoked.
macro_rules! define_worksize_consts {
    () => {
        /// Width of one `Unit`, in bits.
        const BITS: usize = 8 * std::mem::size_of::<Unit>();
        /// Width of one `Unit`, in bytes.
        const ALIGN: usize = std::mem::size_of::<Unit>();
    };
}
29
30macro_rules! define_worksize {
31 (8) => {
32 type Unit = u8;
33 define_worksize_consts!();
34 };
35 (16) => {
36 type Unit = u16;
37 define_worksize_consts!();
38 };
39 (32) => {
40 type Unit = u32;
41 define_worksize_consts!();
42 };
43 (64) => {
44 type Unit = u64;
45 define_worksize_consts!();
46 };
47}
48
49define_worksize!(8);
50
/// Counts how many bits are needed to represent the argument, i.e. the position
/// of its highest set bit plus one (`0` for a zero argument).
///
/// The argument is converted with `as u64` first, so a negative signed value is
/// sign-extended and yields `64`. The integer type of the result is inferred
/// from where the macro is used.
#[macro_export]
macro_rules! ilog {
    ($v:expr) => {{
        let mut remaining = $v as u64;
        let mut width = 0;
        loop {
            if remaining == 0 {
                break width;
            }
            remaining >>= 1;
            width += 1;
        }
    }};
}
65
/// Counts the set bits of the argument and yields the count as a `usize`.
///
/// The argument is converted with `as u64` first, so a negative signed value is
/// sign-extended before counting (for example `-1i32` yields `64`).
#[macro_export]
macro_rules! icount {
    ($v:expr) => {{
        // The standard library population count replaces the manual shift-and-add loop.
        ($v as u64).count_ones() as usize
    }};
}
80
/// Bit-level reader over a borrowed byte slice.
#[derive(Default)]
pub struct BitReader<'a> {
    /// Bit offset (0..=7) inside the byte at `cursor` at which the next read starts.
    pub endbit: i32,

    /// Running total of bits consumed by successful reads.
    pub total_bits: usize,

    /// The bytes being read.
    pub data: &'a [u8],

    /// Index into `data` of the byte that holds the next unread bit.
    pub cursor: usize,
}
96
97impl<'a> BitReader<'a> {
98 pub fn new(data: &'a [u8]) -> Self {
103 Self {
104 endbit: 0,
105 total_bits: 0,
106 cursor: 0,
107 data,
108 }
109 }
110
111 pub fn read(&mut self, mut bits: i32) -> io::Result<i32> {
114 if !(0..=32).contains(&bits) {
115 return_Err!(io::Error::new(io::ErrorKind::InvalidInput, format!("Invalid bit number: {bits}")));
116 }
117 let mut ret: i32;
118 let m = MASK[bits as usize];
119 let origbits = bits;
120 let cursor = self.cursor;
121
122 let ptr_index = |mut index: usize| -> io::Result<u8> {
124 index += cursor;
125 let eof_err = || -> io::Error {
126 io::Error::new(io::ErrorKind::UnexpectedEof, format!("UnexpectedEof when trying to read {origbits} bits from the input position 0x{:x}", index))
127 };
128 self.data.get(index).ok_or(eof_err()).copied()
129 };
130
131 bits += self.endbit;
132 if bits == 0 {
133 return Ok(0);
134 }
135
136 ret = (ptr_index(0)? as i32) >> self.endbit;
137 if bits > 8 {
138 ret |= (ptr_index(1)? as i32) << (8 - self.endbit);
139 if bits > 16 {
140 ret |= (ptr_index(2)? as i32) << (16 - self.endbit);
141 if bits > 24 {
142 ret |= (ptr_index(3)? as i32) << (24 - self.endbit);
143 if bits > 32 && self.endbit != 0 {
144 ret |= (ptr_index(4)? as i32) << (32 - self.endbit);
145 }
146 }
147 }
148 }
149 ret &= m as i32;
150 self.cursor += (bits / 8) as usize;
151 self.endbit = bits & 7;
152 self.total_bits += origbits as usize;
153 Ok(ret)
154 }
155}
156
/// Bit-level writer that packs values into bytes and forwards them to `writer`.
#[derive(Default)]
pub struct BitWriter<W>
where
    W: Write {
    /// Number of bits (0..=7) already used in the last byte of `cache`.
    pub endbit: i32,

    /// Running total of bits accepted by `write`.
    pub total_bits: usize,

    /// Destination that receives the bytes when the cache is flushed.
    pub writer: W,

    /// Staging buffer for bytes that were not yet handed to `writer`.
    pub cache: CursorVecU8,
}
174
175impl<W> BitWriter<W>
176where
177 W: Write {
178 const CACHE_SIZE: usize = 1024;
179
180 pub fn new(writer: W) -> Self {
182 Self {
183 endbit: 0,
184 total_bits: 0,
185 writer,
186 cache: CursorVecU8::default(),
187 }
188 }
189
190 pub fn last_byte(&mut self) -> &mut u8 {
192 if self.cache.is_empty() {
193 self.cache.write_all(&[0u8]).unwrap();
194 }
195 let v = self.cache.get_mut();
196 let len = v.len();
197 &mut v[len - 1]
198 }
199
200 fn write_byte(&mut self, byte: u8) -> io::Result<()> {
202 self.cache.write_all(&[byte])?;
203 if self.cache.len() >= Self::CACHE_SIZE {
204 self.flush()?;
205 }
206 Ok(())
207 }
208
209 pub fn write(&mut self, mut value: u32, mut bits: i32) -> io::Result<()> {
211 if !(0..=32).contains(&bits) {
212 return_Err!(io::Error::new(io::ErrorKind::InvalidInput, format!("Invalid bits {bits}")));
213 }
214 value &= MASK[bits as usize];
215 let origbits = bits;
216 bits += self.endbit;
217
218 *self.last_byte() |= (value << self.endbit) as u8;
219
220 if bits >= 8 {
221 self.write_byte((value >> (8 - self.endbit)) as u8)?;
222 if bits >= 16 {
223 self.write_byte((value >> (16 - self.endbit)) as u8)?;
224 if bits >= 24 {
225 self.write_byte((value >> (24 - self.endbit)) as u8)?;
226 if bits >= 32 {
227 if self.endbit != 0 {
228 self.write_byte((value >> (32 - self.endbit)) as u8)?;
229 } else {
230 self.write_byte(0)?;
231 }
232 }
233 }
234 }
235 }
236
237 self.endbit = bits & 7;
238 self.total_bits += origbits as usize;
239 Ok(())
240 }
241
242 pub fn flush(&mut self) -> io::Result<()> {
243 if self.cache.is_empty() {
244 Ok(())
245 } else if self.endbit == 0 {
246 self.writer.write_all(&self.cache[..])?;
247 self.cache.clear();
248 Ok(())
249 } else {
250 let len = self.cache.len();
251 let last_byte = self.cache[len - 1];
252 self.writer.write_all(&self.cache[..(len - 1)])?;
253 self.cache.clear();
254 self.cache.write_all(&[last_byte])?;
255 Ok(())
256 }
257 }
258
259 pub fn force_flush(&mut self) -> io::Result<()> {
260 self.writer.write_all(&self.cache[..])?;
261 self.cache.clear();
262 self.endbit = 0;
263 Ok(())
264 }
265}
266
/// A `BitWriter` whose destination is a `CursorVecU8`.
pub type BitWriterCursor = BitWriter<CursorVecU8>;
269
/// A `BitWriter` whose destination is a boxed `Writer` trait object.
pub type BitWriterObj = BitWriter<Box<dyn Writer>>;
272
273impl BitWriterCursor {
274 pub fn into_bytes(mut self) -> Vec<u8> {
276 self.force_flush().unwrap();
278 self.writer.into_inner()
279 }
280}
281
/// Reads `$bits` bits from `$bitreader` and yields the value.
///
/// A failed read panics on the spot when `DEBUG_ON_READ_BITS` is true;
/// otherwise the error is propagated with `?`, so the macro must then be used
/// inside a function that can return it.
#[macro_export]
macro_rules! read_bits {
    ($bitreader:ident, $bits:expr) => {
        match DEBUG_ON_READ_BITS {
            true => $bitreader.read($bits).unwrap(),
            false => $bitreader.read($bits)?,
        }
    };
}
293
/// Reads 32 bits through `read_bits!` and reinterprets them as an IEEE-754 `f32`.
///
/// The bit pattern is preserved exactly; `f32::from_bits` does the
/// reinterpretation without any `unsafe` code.
#[macro_export]
macro_rules! read_f32 {
    ($bitreader:ident) => {
        f32::from_bits(read_bits!($bitreader, 32) as u32)
    };
}
301
/// Writes the low `$bits` bits of `$data` (converted with `as u32`) to `$bitwriter`.
///
/// A failed write panics on the spot when `DEBUG_ON_WRITE_BITS` is true;
/// otherwise the error is propagated with `?`, so the macro must then be used
/// inside a function that can return it.
#[macro_export]
macro_rules! write_bits {
    ($bitwriter:ident, $data:expr, $bits:expr) => {
        match DEBUG_ON_WRITE_BITS {
            true => $bitwriter.write($data as u32, $bits).unwrap(),
            false => $bitwriter.write($data as u32, $bits)?,
        }
    };
}
313
/// Writes an `f32` as its 32-bit IEEE-754 bit pattern through `write_bits!`.
///
/// `$data` must be an `f32`; `f32::to_bits` replaces the former unsafe
/// `transmute`, which would have reinterpreted any 4-byte value without complaint.
#[macro_export]
macro_rules! write_f32 {
    ($bitwriter:ident, $data:expr) => {
        write_bits!($bitwriter, f32::to_bits($data), 32)
    };
}
321
/// Reads `$length` bytes, eight bits at a time through `read_bits!`, and yields
/// them as a `Vec<u8>`.
///
/// `$length` is evaluated exactly once, so an expression with side effects or
/// a costly computation is safe to pass.
#[macro_export]
macro_rules! read_slice {
    ($bitreader:ident, $length:expr) => {
        {
            let length: usize = $length;
            let mut ret = Vec::<u8>::with_capacity(length);
            for _ in 0..length {
                ret.push(read_bits!($bitreader, 8) as u8);
            }
            ret
        }
    };
}
335
/// Reads `$length` bytes through `read_slice!` and decodes them as UTF-8.
///
/// Yields `Ok(String)` on success, or an `io::Error` of kind `InvalidData`
/// whose message contains a lossy rendering of the offending bytes. The name
/// `io` must be in scope where the macro is used.
#[macro_export]
macro_rules! read_string {
    ($bitreader:ident, $length:expr) => {
        {
            let raw = read_slice!($bitreader, $length);
            String::from_utf8(raw).map_err(|bad| {
                io::Error::new(io::ErrorKind::InvalidData, format!("Parse UTF-8 failed: {}", String::from_utf8_lossy(bad.as_bytes())))
            })
        }
    };
}
349
/// Writes every element of `$data` through `write_bits!`, using the element's
/// own size in bits as the width.
///
/// `$data` must offer `iter()` over `Copy` elements.
#[macro_export]
macro_rules! write_slice {
    ($bitwriter:ident, $data:expr) => {
        for item in $data.iter().copied() {
            let width = std::mem::size_of_val(&item) as i32 * 8;
            write_bits!($bitwriter, item, width);
        }
    };
}
359
/// Writes the bytes returned by `$string.as_bytes()` by delegating to `write_slice!`.
///
/// This macro itself adds no length prefix and no terminator.
#[macro_export]
macro_rules! write_string {
    ($bitwriter:ident, $string:expr) => {
        // `as_bytes()` is passed straight into `write_slice!`, so a temporary
        // produced by `$string` stays alive for the whole write loop.
        write_slice!($bitwriter, $string.as_bytes());
    };
}
367
/// Rounds `size` up to the nearest multiple of `alignment`.
///
/// A `size` of zero stays zero (whatever `alignment` is).
///
/// # Panics
/// Panics with a division by zero when `alignment` is `0` and `size` is not.
pub fn align(size: usize, alignment: usize) -> usize {
    match size {
        0 => 0,
        nonzero => {
            // Number of `alignment`-sized blocks needed to cover `nonzero`.
            let blocks = (nonzero - 1) / alignment + 1;
            blocks * alignment
        }
    }
}
376
/// Reinterprets the bytes of a `Vec<S>` as a `Vec<D>` (native byte order).
///
/// The elements are moved bit for bit; no destructor of `S` is run. The
/// existing allocation is reused only when that is sound, i.e. when `S` and
/// `D` have the same alignment and the capacity in bytes is a whole number of
/// `D` elements. Otherwise the bytes are copied into a fresh, correctly
/// aligned allocation and the old one is freed.
///
/// NOTE(review): the function is safe to call, but it is only meaningful for
/// plain-data types for which every bit pattern is a valid `D` (such as the
/// unsigned integers this file uses it with).
///
/// # Panics
/// Panics when the total number of bytes is not a multiple of `size_of::<D>()`
/// or when `D` is a zero-sized type.
pub fn transmute_vector<S, D>(vector: Vec<S>) -> Vec<D>
where
    S: Sized,
    D: Sized {

    use std::{any::type_name, mem::{align_of, size_of, ManuallyDrop}, ptr};

    let s_size = size_of::<S>();
    let d_size = size_of::<D>();
    let size_in_bytes = s_size * vector.len();

    // A zero-sized destination can never hold the bytes; report it with the
    // same message instead of a bare division-by-zero panic.
    if d_size == 0 || size_in_bytes % d_size != 0 {
        let s_name = type_name::<S>();
        let d_name = type_name::<D>();
        panic!("Could not transmute from Vec<{s_name}> to Vec<{d_name}>: the number of bytes {size_in_bytes} is not divisible to {d_size}.")
    }

    let d_len = size_in_bytes / d_size;
    let capacity_in_bytes = s_size * vector.capacity();
    let can_reuse_allocation = s_size != 0
        && align_of::<S>() == align_of::<D>()
        && capacity_in_bytes % d_size == 0;

    if can_reuse_allocation {
        let mut source = ManuallyDrop::new(vector);
        // SAFETY: the pointer comes from a live `Vec<S>` that is never dropped
        // (ManuallyDrop), so ownership moves to the new vector. `S` and `D`
        // share the same alignment and `capacity * size_of::<D>()` equals the
        // allocated byte size exactly, so the layout used for deallocation
        // matches the one used for allocation. `d_len` elements are covered by
        // the initialized bytes of the source.
        unsafe {
            Vec::<D>::from_raw_parts(source.as_mut_ptr() as *mut D, d_len, capacity_in_bytes / d_size)
        }
    } else {
        let mut source = vector;
        let mut dest = Vec::<D>::with_capacity(d_len);
        // SAFETY: `source` holds `size_in_bytes` initialized bytes and `dest`
        // has room for `d_len * d_size == size_in_bytes` bytes in a separate
        // allocation, so the ranges are valid and cannot overlap. The length
        // of `source` is zeroed afterwards so that the moved elements are not
        // dropped a second time when its allocation is released.
        unsafe {
            ptr::copy_nonoverlapping(source.as_ptr() as *const u8, dest.as_mut_ptr() as *mut u8, size_in_bytes);
            dest.set_len(d_len);
            source.set_len(0);
        }
        dest
    }
}
400
/// Removes the first `bits` bits from a bit stream of `total_bits` bits and
/// returns the remainder, moved to the front.
///
/// Bits are counted from the least-significant bit of `data[0]`.
///
/// * `bits == 0` returns a copy of `data`.
/// * `bits >= total_bits` returns an empty vector.
/// * A whole-byte shift returns `data` without its first `bits / 8` bytes.
/// * Otherwise the result is cut to the bytes needed for `total_bits - bits` bits.
///
/// NOTE(review): assumes `data` holds at least `ceil(total_bits / 8)` bytes;
/// a shorter slice can make the slicing or the loop bound below panic.
pub fn shift_data_to_front(data: &[u8], bits: usize, total_bits: usize) -> Vec<u8> {
    if bits == 0 {
        data.to_owned()
    } else if bits >= total_bits {
        Vec::new()
    } else {
        let shifted_total_bits = total_bits - bits;
        // Whole bytes are removed by slicing; only the sub-byte rest needs real shifting.
        let mut data = {
            let bytes_moving = bits >> 3;
            data[bytes_moving..].to_vec()
        };
        let bits = bits & 7;
        if bits == 0 {
            data
        } else {
            // Pad to a whole number of work units so the bytes can be viewed as `Unit`s.
            data.resize(align(data.len(), ALIGN), 0);
            // NOTE(review): for a `Unit` wider than `u8` this presumably relies on a
            // little-endian target, since the units are built from native-order bytes.
            let mut to_shift: Vec<Unit> = transmute_vector(data);

            // Low part: `data1` without its first `bits` bits; high part: the
            // first `bits` bits of the following unit.
            fn combine_bits(data1: Unit, data2: Unit, bits: usize) -> Unit {
                let move_high = BITS - bits;
                (data1 >> bits) | (data2 << move_high)
            }

            // Front to back, so `to_shift[i + 1]` is still unmodified when it is read.
            for i in 0..(to_shift.len() - 1) {
                to_shift[i] = combine_bits(to_shift[i], to_shift[i + 1], bits);
            }

            // The last unit has no successor; zeros are shifted in from the top.
            let last = to_shift.pop().unwrap() >> bits;
            to_shift.push(last);

            let mut ret = transmute_vector(to_shift);
            // Drop the padding and any byte that no longer holds stream bits.
            ret.truncate(align(shifted_total_bits, 8) / 8);
            ret
        }
    }
}
438
439pub fn shift_data_to_back(data: &[u8], bits: usize, total_bits: usize) -> Vec<u8> {
441 if bits == 0 {
442 data.to_owned()
443 } else {
444 let shifted_total_bits = total_bits + bits;
445 let data = {
446 let bytes_added = align(bits, 8) / 8;
447 let data: Vec<u8> = [vec![0u8; bytes_added], data.to_owned()].iter().flatten().copied().collect();
448 data
449 };
450 let bits = bits & 7;
451 if bits == 0 {
452 data
453 } else {
454 let lsh = 8 - bits;
455 shift_data_to_front(&data, lsh, shifted_total_bits + lsh)
456 }
457 }
458}
459
460
/// A byte buffer together with the number of bits in it that carry data.
#[derive(Default, Clone, PartialEq, Eq)]
pub struct BitwiseData {
    /// The bytes holding the bits.
    pub data: Vec<u8>,

    /// Number of valid bits in `data`.
    pub total_bits: usize,
}
471
472impl BitwiseData {
473 pub fn new(data: &[u8], total_bits: usize) -> Self {
474 let mut ret = Self {
475 data: data[..Self::calc_total_bytes(total_bits)].to_vec(),
476 total_bits,
477 };
478 ret.remove_residue();
479 ret
480 }
481
482 pub fn from_bytes(data: &[u8]) -> Self {
484 Self {
485 data: data.to_vec(),
486 total_bits: data.len() * 8,
487 }
488 }
489
490 fn remove_residue(&mut self) {
492 let residue_bits = self.total_bits & 7;
493 if residue_bits == 0 {
494 return;
495 }
496 if let Some(byte) = self.data.pop() { self.data.push(byte & MASK8[residue_bits]) }
497 }
498
499 pub fn get_total_bits(&self) -> usize {
501 self.total_bits
502 }
503
504 pub fn get_total_bytes(&self) -> usize {
506 Self::calc_total_bytes(self.total_bits)
507 }
508
509 pub fn calc_total_bytes(total_bits: usize) -> usize {
511 align(total_bits, 8) / 8
512 }
513
514 pub fn fit_to_aligned_size(&mut self) {
516 self.data.resize(align(self.total_bits, BITS) / 8, 0);
517 }
518
519 pub fn shrink_to_fit(&mut self) {
521 self.data.truncate(self.get_total_bytes());
522 self.remove_residue();
523 }
524
525 pub fn is_aligned_size(&self) -> bool {
527 self.data.len() == align(self.data.len(), ALIGN)
528 }
529
530 pub fn split(&self, split_at_bit: usize) -> (Self, Self) {
532 if split_at_bit == 0 {
533 (Self::default(), self.clone())
534 } else if split_at_bit >= self.total_bits {
535 (self.clone(), Self::default())
536 } else {
537 let data1 = {
538 let mut data = self.clone();
539 data.total_bits = split_at_bit;
540 data.shrink_to_fit();
541 let last_bits = data.total_bits & 7;
542 if last_bits != 0 {
543 let last_byte = data.data.pop().unwrap();
544 data.data.push(last_byte & MASK8[last_bits]);
545 }
546 data
547 };
548 let data2 = Self {
549 data: shift_data_to_front(&self.data, split_at_bit, self.total_bits),
550 total_bits: self.total_bits - split_at_bit,
551 };
552 (data1, data2)
553 }
554 }
555
556 pub fn concat(&mut self, rhs: &Self) {
558 if rhs.total_bits == 0 {
559 return;
560 }
561 self.shrink_to_fit();
562 let shifts = self.total_bits & 7;
563 if shifts == 0 {
564 self.data.extend(&rhs.data);
565 } else {
566 let shift_left = 8 - shifts;
567 let last_byte = self.data.pop().unwrap();
568 self.data.push(last_byte | (rhs.data[0] << shifts));
569 self.data.extend(shift_data_to_front(&rhs.data, shift_left, rhs.total_bits));
570 }
571 self.total_bits += rhs.total_bits;
572 }
573
574 pub fn into_bytes(mut self) -> Vec<u8> {
576 self.shrink_to_fit();
577 self.data
578 }
579}
580
impl Debug for BitwiseData {
    /// Formats the value as a `BitwiseData` struct with its two fields.
    ///
    /// `data` is rendered through the crate's `format_array!` macro with the
    /// `hex2` argument (presumably two-digit hexadecimal — confirm against the
    /// macro definition) instead of the default `Vec<u8>` debug output.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        f.debug_struct("BitwiseData")
        .field("data", &format_args!("{}", format_array!(self.data, hex2)))
        .field("total_bits", &self.total_bits)
        .finish()
    }
}