revorbis/
bitwise.rs

1use std::{
2    io::{self, Write},
3    fmt::{self, Debug, Formatter},
4};
5
6use crate::*;
7use io_utils::{Writer, CursorVecU8};
8
/// Low-bit masks for a single byte: `MASK8[n]` has exactly the `n` least significant bits set (`n` in `0..=8`).
const MASK8: [u8; 9] = {
    let mut table = [0u8; 9];
    let mut width = 1;
    while width <= 8 {
        table[width] = u8::MAX >> (8 - width);
        width += 1;
    }
    table
};
10
/// Low-bit masks for a 32-bit word: `MASK[n]` has exactly the `n` least significant bits set (`n` in `0..=32`).
const MASK: [u32; 33] = {
    let mut table = [0u32; 33];
    let mut width = 1;
    while width <= 32 {
        table[width] = u32::MAX >> (32 - width);
        width += 1;
    }
    table
};
22
/// Derive the work-unit constants from the `Unit` type alias visible where the macro is expanded.
///
/// * `ALIGN`: the size of one `Unit` in bytes
/// * `BITS`: the size of one `Unit` in bits
macro_rules! define_worksize_consts {
    () => {
        const ALIGN: usize = std::mem::size_of::<Unit>();
        const BITS: usize = ALIGN * 8;
    }
}
29
/// Choose the unsigned integer type used as the work unit.
///
/// `define_worksize!(8 | 16 | 32 | 64)` declares the `Unit` type alias with that many bits and then
/// expands `define_worksize_consts!()` to declare the matching `BITS` and `ALIGN` constants.
macro_rules! define_worksize {
    // Internal arm: every public arm funnels into this one.
    (@unit $unit:ty) => {
        type  Unit = $unit;
        define_worksize_consts!();
    };
    (8) => {
        define_worksize!(@unit u8);
    };
    (16) => {
        define_worksize!(@unit u16);
    };
    (32) => {
        define_worksize!(@unit u32);
    };
    (64) => {
        define_worksize!(@unit u64);
    };
}
48
// Work on one byte at a time: `Unit` is `u8`, so `BITS` is 8 and `ALIGN` is 1.
define_worksize!(8);
50
/// * Count how many bits are needed to hold the value: the position of its highest `1` bit, or 0 for 0.
/// * The value is cast to `u64` first; the integer type of the result is inferred at the call site.
#[macro_export]
macro_rules! ilog {
    ($v:expr) => {
        {
            let mut remaining = $v as u64;
            let mut width = 0;
            loop {
                if remaining == 0 {
                    break width;
                }
                remaining >>= 1;
                width += 1;
            }
        }
    }
}
65
/// * Count the `1` bits of the value after casting it to `u64`. The result is a `usize`.
#[macro_export]
macro_rules! icount {
    ($v:expr) => {
        {
            let widened = $v as u64;
            widened.count_ones() as usize
        }
    }
}
80
/// * BitReader: read vorbis data bit by bit from a borrowed byte slice.
/// * Within a byte, the lowest bits are read first.
#[derive(Default)]
pub struct BitReader<'a> {
    /// * How many bits of the byte at `cursor` have already been consumed
    pub endbit: i32,

    /// * How many bits have been read in total
    pub total_bits: usize,

    /// * The borrowed slice of data to read from
    pub data: &'a [u8],

    /// * Index of the byte currently being read
    pub cursor: usize,
}
96
97impl<'a> BitReader<'a> {
98    /// * `data` is decapsulated from the Ogg stream
99    /// * `cursor` is the read position of the `BitReader`
100    /// * Pass `data` as a slice that begins from the part you want to read,
101    ///   Then you'll get the `cursor` to indicate how many bytes this part of data takes.
102    pub fn new(data: &'a [u8]) -> Self {
103        Self {
104            endbit: 0,
105            total_bits: 0,
106            cursor: 0,
107            data,
108        }
109    }
110
111    /// * Read data bit by bit
112    /// * bits <= 32
113    pub fn read(&mut self, mut bits: i32) -> io::Result<i32> {
114        if !(0..=32).contains(&bits) {
115            return_Err!(io::Error::new(io::ErrorKind::InvalidInput, format!("Invalid bit number: {bits}")));
116        }
117        let mut ret: i32;
118        let m = MASK[bits as usize];
119        let origbits = bits;
120        let cursor = self.cursor;
121
122        // Don't want it panic, and don't want an Option.
123        let ptr_index = |mut index: usize| -> io::Result<u8> {
124            index += cursor;
125            let eof_err = || -> io::Error {
126                io::Error::new(io::ErrorKind::UnexpectedEof, format!("UnexpectedEof when trying to read {origbits} bits from the input position 0x{:x}", index))
127            };
128            self.data.get(index).ok_or(eof_err()).copied()
129        };
130
131        bits += self.endbit;
132        if bits == 0 {
133            return Ok(0);
134        }
135
136        ret = (ptr_index(0)? as i32) >> self.endbit;
137        if bits > 8 {
138            ret |= (ptr_index(1)? as i32) << (8 - self.endbit);
139            if bits > 16 {
140                ret |= (ptr_index(2)? as i32) << (16 - self.endbit);
141                if bits > 24 {
142                    ret |= (ptr_index(3)? as i32) << (24 - self.endbit);
143                    if bits > 32 && self.endbit != 0 {
144                        ret |= (ptr_index(4)? as i32) << (32 - self.endbit);
145                    }
146                }
147            }
148        }
149        ret &= m as i32;
150        self.cursor += (bits / 8) as usize;
151        self.endbit = bits & 7;
152        self.total_bits += origbits as usize;
153        Ok(ret)
154    }
155}
156
/// * BitWriter: write vorbis data bit by bit
/// * Bytes are collected in `cache` first and handed over to `writer` when flushing.
#[derive(Default)]
pub struct BitWriter<W>
where
    W: Write {
    /// * How many bits of the last cached byte are already occupied
    pub endbit: i32,

    /// * How many bits have been written in total
    pub total_bits: usize,

    /// * The sink
    pub writer: W,

    /// * The cache that holds data to be flushed
    pub cache: CursorVecU8,
}
174
175impl<W> BitWriter<W>
176where
177    W: Write {
178    const CACHE_SIZE: usize = 1024;
179
180    /// * Create a `CursorVecU8` to write
181    pub fn new(writer: W) -> Self {
182        Self {
183            endbit: 0,
184            total_bits: 0,
185            writer,
186            cache: CursorVecU8::default(),
187        }
188    }
189
190    /// * Get the last byte for modifying it
191    pub fn last_byte(&mut self) -> &mut u8 {
192        if self.cache.is_empty() {
193            self.cache.write_all(&[0u8]).unwrap();
194        }
195        let v = self.cache.get_mut();
196        let len = v.len();
197        &mut v[len - 1]
198    }
199
200    /// * Write data by bytes one by one
201    fn write_byte(&mut self, byte: u8) -> io::Result<()> {
202        self.cache.write_all(&[byte])?;
203        if self.cache.len() >= Self::CACHE_SIZE {
204            self.flush()?;
205        }
206        Ok(())
207    }
208
209    /// * Write data in bits, max is 32 bit.
210    pub fn write(&mut self, mut value: u32, mut bits: i32) -> io::Result<()> {
211        if !(0..=32).contains(&bits) {
212            return_Err!(io::Error::new(io::ErrorKind::InvalidInput, format!("Invalid bits {bits}")));
213        }
214        value &= MASK[bits as usize];
215        let origbits = bits;
216        bits += self.endbit;
217
218        *self.last_byte() |= (value << self.endbit) as u8;
219
220        if bits >= 8 {
221            self.write_byte((value >> (8 - self.endbit)) as u8)?;
222            if bits >= 16 {
223                self.write_byte((value >> (16 - self.endbit)) as u8)?;
224                if bits >= 24 {
225                    self.write_byte((value >> (24 - self.endbit)) as u8)?;
226                    if bits >= 32 {
227                        if self.endbit != 0 {
228                            self.write_byte((value >> (32 - self.endbit)) as u8)?;
229                        } else {
230                            self.write_byte(0)?;
231                        }
232                    }
233                }
234            }
235        }
236
237        self.endbit = bits & 7;
238        self.total_bits += origbits as usize;
239        Ok(())
240    }
241
242    pub fn flush(&mut self) -> io::Result<()> {
243        if self.cache.is_empty() {
244            Ok(())
245        } else if self.endbit == 0 {
246            self.writer.write_all(&self.cache[..])?;
247            self.cache.clear();
248            Ok(())
249        } else {
250            let len = self.cache.len();
251            let last_byte = self.cache[len - 1];
252            self.writer.write_all(&self.cache[..(len - 1)])?;
253            self.cache.clear();
254            self.cache.write_all(&[last_byte])?;
255            Ok(())
256        }
257    }
258
259    pub fn force_flush(&mut self) -> io::Result<()> {
260        self.writer.write_all(&self.cache[..])?;
261        self.cache.clear();
262        self.endbit = 0;
263        Ok(())
264    }
265}
266
/// * The specialized `BitWriter` that uses `CursorVecU8` as its sink.
pub type BitWriterCursor = BitWriter<CursorVecU8>;

/// * The specialized `BitWriter` that uses `Box<dyn Writer>` as its sink.
pub type BitWriterObj = BitWriter<Box<dyn Writer>>;
272
273impl BitWriterCursor {
274    /// * Get the inner byte array and consumes the writer.
275    pub fn into_bytes(mut self) -> Vec<u8> {
276        // Make sure the last byte was written
277        self.force_flush().unwrap();
278        self.writer.into_inner()
279    }
280}
281
/// * Read `$bits` bits of data using `BitReader`
/// * When `DEBUG_ON_READ_BITS` is `true`, a failed read panics right here (`unwrap()`);
///   otherwise the error is returned from the enclosing function with `?`.
#[macro_export]
macro_rules! read_bits {
    ($bitreader:ident, $bits:expr) => {
        match $bitreader.read($bits) {
            outcome if DEBUG_ON_READ_BITS => outcome.unwrap(),
            outcome => outcome?,
        }
    };
}
293
/// * Read a `f32` using `BitReader`
/// * 32 bits are read with `read_bits!` and their bit pattern is reinterpreted as a `f32`.
///   No `unsafe` is needed for that: `f32::from_bits` does exactly this.
#[macro_export]
macro_rules! read_f32 {
    ($bitreader:ident) => {
        f32::from_bits(read_bits!($bitreader, 32) as u32)
    };
}
301
/// * Write the lowest `$bits` bits of `$data` (cast to `u32`) using `BitWriter<W>`
/// * When `DEBUG_ON_WRITE_BITS` is `true`, a failed write panics right here (`unwrap()`);
///   otherwise the error is returned from the enclosing function with `?`.
#[macro_export]
macro_rules! write_bits {
    ($bitwriter:ident, $data:expr, $bits:expr) => {
        match $bitwriter.write($data as u32, $bits) {
            outcome if DEBUG_ON_WRITE_BITS => outcome.unwrap(),
            outcome => outcome?,
        }
    };
}
313
/// * Write a `f32` using `BitWriter<W>`
/// * The bit pattern of the `f32` is written as 32 bits with `write_bits!`.
///   No `unsafe` is needed for that: `f32::to_bits` does exactly this.
#[macro_export]
macro_rules! write_f32 {
    ($bitwriter:ident, $data:expr) => {
        write_bits!($bitwriter, f32::to_bits($data), 32)
    };
}
321
/// * Read a byte array of `$length` bytes using the `BitReader`, 8 bits at a time
/// * `$length` must be a `usize` expression. It is evaluated exactly once, so it is safe to pass
///   an expression with side effects, e.g. one that reads the length from the same reader.
#[macro_export]
macro_rules! read_slice {
    ($bitreader:ident, $length:expr) => {
        {
            let length: usize = $length;
            let mut ret = Vec::<u8>::with_capacity(length);
            for _ in 0..length {
                ret.push(read_bits!($bitreader, 8) as u8);
            }
            ret
        }
    };
}
335
/// * Read a sized string using the `BitReader`
/// * `$length` bytes are read with `read_slice!` and then checked to be valid UTF-8.
/// * The result is a `Result<String, std::io::Error>`; invalid UTF-8 gives an `InvalidData` error
///   whose message shows the bytes decoded lossily.
/// * All paths are absolute, so the call site does not need `io` in scope.
#[macro_export]
macro_rules! read_string {
    ($bitreader:ident, $length:expr) => {
        {
            let bytes = read_slice!($bitreader, $length);
            // `String::from_utf8` takes over the buffer, no second copy is made.
            ::std::string::String::from_utf8(bytes).map_err(|e| {
                ::std::io::Error::new(
                    ::std::io::ErrorKind::InvalidData,
                    format!("Parse UTF-8 failed: {}", ::std::string::String::from_utf8_lossy(e.as_bytes())),
                )
            })
        }
    };
}
349
/// * Write a slice to the `BitWriter`, element by element
/// * Each element is written with `write_bits!` using as many bits as the element type occupies.
#[macro_export]
macro_rules! write_slice {
    ($bitwriter:ident, $data:expr) => {
        for element in $data.iter().copied() {
            let width = std::mem::size_of_val(&element) as i32 * 8;
            write_bits!($bitwriter, element, width);
        }
    };
}
359
/// * Write the bytes of a string (`$string.as_bytes()`) to the `BitWriter` through `write_slice!`
/// * Only the bytes themselves are written here; this macro emits no length prefix.
#[macro_export]
macro_rules! write_string {
    ($bitwriter:ident, $string:expr) => {
        write_slice!($bitwriter, $string.as_bytes());
    };
}
367
/// * Alignment calculation: round `size` up to the next multiple of `alignment`
/// * A `size` of 0 stays 0. Any other `size` panics if `alignment` is 0 (division by zero).
pub fn align(size: usize, alignment: usize) -> usize {
    match size {
        0 => 0,
        nonzero => (1 + (nonzero - 1) / alignment) * alignment,
    }
}
376
/// * Transmute vector: reinterpret the bytes of a `Vec<S>` as a `Vec<D>`.
/// * The buffer is taken over as it is (no copy, same memory location) when `S` and `D` have the same
///   alignment and the capacity in bytes is a whole number of `D`s. That is what `Vec::from_raw_parts`
///   requires. In every other case the bytes are copied into a fresh `Vec<D>` and the old buffer is freed.
/// * The elements of the source vector are never dropped; their bytes live on in the result.
/// * Panics if the number of bytes is not divisible by the size of `D`, or if `D` is zero-sized.
/// * The caller must make sure that the bytes are valid values of `D`.
///   Will crash if you don't know what you are doing.
pub fn transmute_vector<S, D>(vector: Vec<S>) -> Vec<D>
where
    S: Sized,
    D: Sized {

    use std::{any::type_name, mem::{align_of, size_of, ManuallyDrop}, ptr};
    let s_size = size_of::<S>();
    let d_size = size_of::<D>();
    let s_name = type_name::<S>();
    let d_name = type_name::<D>();
    if d_size == 0 {
        panic!("Could not transmute from Vec<{s_name}> to Vec<{d_name}>: the destination type is zero-sized.")
    }
    let size_in_bytes = s_size * vector.len();
    let remain_size = size_in_bytes % d_size;
    if remain_size != 0 {
        panic!("Could not transmute from Vec<{s_name}> to Vec<{d_name}>: the number of bytes {size_in_bytes} is not divisible to {d_size}.")
    }
    let len = size_in_bytes / d_size;
    let mut s = ManuallyDrop::new(vector);
    // A `Vec` of a zero-sized type reports a capacity of `usize::MAX`; the product is 0 then.
    let capacity_in_bytes = s_size * s.capacity();
    let same_layout = s_size != 0
        && align_of::<S>() == align_of::<D>()
        && capacity_in_bytes % d_size == 0;
    if same_layout {
        // SAFETY: the pointer comes from a `Vec` that is never dropped (`ManuallyDrop`), the
        // alignment of `S` and `D` is equal, and `capacity * size_of::<D>()` equals the size of
        // the allocation, so the allocation is described correctly. `len <= capacity` holds
        // because both were derived from the same byte counts.
        unsafe {
            Vec::<D>::from_raw_parts(s.as_mut_ptr() as *mut D, len, capacity_in_bytes / d_size)
        }
    } else {
        let mut ret = Vec::<D>::with_capacity(len);
        // SAFETY: the source holds `size_in_bytes` initialized bytes, the destination has room
        // for `len * d_size == size_in_bytes` bytes, and the two buffers are distinct. After the
        // copy the elements live in `ret`, so the old vector is emptied before it is dropped:
        // only its buffer is released, no element is dropped twice.
        unsafe {
            ptr::copy_nonoverlapping(s.as_ptr() as *const u8, ret.as_mut_ptr() as *mut u8, size_in_bytes);
            ret.set_len(len);
            s.set_len(0);
            ManuallyDrop::drop(&mut s);
        }
        ret
    }
}
400
401/// * Shift an array of bits to the front. In a byte, the lower bits are the front bits.
402pub fn shift_data_to_front(data: &[u8], bits: usize, total_bits: usize) -> Vec<u8> {
403    if bits == 0 {
404        data.to_owned()
405    } else if bits >= total_bits {
406        Vec::new()
407    } else {
408        let shifted_total_bits = total_bits - bits;
409        let mut data = {
410            let bytes_moving = bits >> 3;
411            data[bytes_moving..].to_vec()
412        };
413        let bits = bits & 7;
414        if bits == 0 {
415            data
416        } else {
417            data.resize(align(data.len(), ALIGN), 0);
418            let mut to_shift: Vec<Unit> = transmute_vector(data);
419
420            fn combine_bits(data1: Unit, data2: Unit, bits: usize) -> Unit {
421                let move_high = BITS - bits;
422                (data1 >> bits) | (data2 << move_high)
423            }
424
425            for i in 0..(to_shift.len() - 1) {
426                to_shift[i] = combine_bits(to_shift[i], to_shift[i + 1], bits);
427            }
428
429            let last = to_shift.pop().unwrap() >> bits;
430            to_shift.push(last);
431
432            let mut ret = transmute_vector(to_shift);
433            ret.truncate(align(shifted_total_bits, 8) / 8);
434            ret
435        }
436    }
437}
438
439/// * Shift an array of bits to the back. In a byte, the higher bits are the back bits.
440pub fn shift_data_to_back(data: &[u8], bits: usize, total_bits: usize) -> Vec<u8> {
441    if bits == 0 {
442        data.to_owned()
443    } else {
444        let shifted_total_bits = total_bits + bits;
445        let data = {
446            let bytes_added = align(bits, 8) / 8;
447            let data: Vec<u8> = [vec![0u8; bytes_added], data.to_owned()].iter().flatten().copied().collect();
448            data
449        };
450        let bits = bits & 7;
451        if bits == 0 {
452            data
453        } else {
454            let lsh = 8 - bits;
455            shift_data_to_front(&data, lsh, shifted_total_bits + lsh)
456        }
457    }
458}
459
460
/// * A utility for you to manipulate data bitwise, mainly to concatenate data in bits or to split data from a specific bit position.
/// * This is mainly used for Vorbis data parsing.
#[derive(Default, Clone, PartialEq, Eq)]
pub struct BitwiseData {
    /// * The bits, stored as bytes. In a byte, the lower bits come first.
    pub data: Vec<u8>,

    /// * How many bits of `data` are in use
    pub total_bits: usize,
}
471
472impl BitwiseData {
473    pub fn new(data: &[u8], total_bits: usize) -> Self {
474        let mut ret = Self {
475            data: data[..Self::calc_total_bytes(total_bits)].to_vec(),
476            total_bits,
477        };
478        ret.remove_residue();
479        ret
480    }
481
482    /// * Construct from bytes
483    pub fn from_bytes(data: &[u8]) -> Self {
484        Self {
485            data: data.to_vec(),
486            total_bits: data.len() * 8,
487        }
488    }
489
490    /// * If there are any `1` bits outside of the byte array, erase them to zeros.
491    fn remove_residue(&mut self) {
492        let residue_bits = self.total_bits & 7;
493        if residue_bits == 0 {
494            return;
495        }
496        if let Some(byte) = self.data.pop() { self.data.push(byte & MASK8[residue_bits]) }
497    }
498
499    /// * Get the number of total bits in the `data` field
500    pub fn get_total_bits(&self) -> usize {
501        self.total_bits
502    }
503
504    /// * Get the number of bytes that are just enough to contain all of the bits.
505    pub fn get_total_bytes(&self) -> usize {
506        Self::calc_total_bytes(self.total_bits)
507    }
508
509    /// * Get the number of bytes that are just enough to contain all of the bits.
510    pub fn calc_total_bytes(total_bits: usize) -> usize {
511        align(total_bits, 8) / 8
512    }
513
514    /// * Resize to the aligned size. Doing this is for `shift_data_to_front()` and `shift_data_to_back()` to manipulate bits efficiently.
515    pub fn fit_to_aligned_size(&mut self) {
516        self.data.resize(align(self.total_bits, BITS) / 8, 0);
517    }
518
519    /// * Resize to the number of bytes that are just enough to contain all of the bits.
520    pub fn shrink_to_fit(&mut self) {
521        self.data.truncate(self.get_total_bytes());
522        self.remove_residue();
523    }
524
525    /// * Check if the data length is just the aligned size.
526    pub fn is_aligned_size(&self) -> bool {
527        self.data.len() == align(self.data.len(), ALIGN)
528    }
529
530    /// * Breakdown to 2 parts of the data at the specific bitvise position.
531    pub fn split(&self, split_at_bit: usize) -> (Self, Self) {
532        if split_at_bit == 0 {
533            (Self::default(), self.clone())
534        } else if split_at_bit >= self.total_bits {
535            (self.clone(), Self::default())
536        } else {
537            let data1 = {
538                let mut data = self.clone();
539                data.total_bits = split_at_bit;
540                data.shrink_to_fit();
541                let last_bits = data.total_bits & 7;
542                if last_bits != 0 {
543                    let last_byte = data.data.pop().unwrap();
544                    data.data.push(last_byte & MASK8[last_bits]);
545                }
546                data
547            };
548            let data2 = Self {
549                data: shift_data_to_front(&self.data, split_at_bit, self.total_bits),
550                total_bits: self.total_bits - split_at_bit,
551            };
552            (data1, data2)
553        }
554    }
555
556    /// * Concat another `BitwiseData` to the bitstream, without the gap.
557    pub fn concat(&mut self, rhs: &Self) {
558        if rhs.total_bits == 0 {
559            return;
560        }
561        self.shrink_to_fit();
562        let shifts = self.total_bits & 7;
563        if shifts == 0 {
564            self.data.extend(&rhs.data);
565        } else {
566            let shift_left = 8 - shifts;
567            let last_byte = self.data.pop().unwrap();
568            self.data.push(last_byte | (rhs.data[0] << shifts));
569            self.data.extend(shift_data_to_front(&rhs.data, shift_left, rhs.total_bits));
570        }
571        self.total_bits += rhs.total_bits;
572    }
573
574    /// * Turn to byte array
575    pub fn into_bytes(mut self) -> Vec<u8> {
576        self.shrink_to_fit();
577        self.data
578    }
579}
580
/// * Debug output: `data` is rendered by the `format_array!` macro (called with `hex2`), followed by `total_bits`.
impl Debug for BitwiseData {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        f.debug_struct("BitwiseData")
        .field("data", &format_args!("{}", format_array!(self.data, hex2)))
        .field("total_bits", &self.total_bits)
        .finish()
    }
}
588}