/// MSB-first bit reader over a borrowed byte slice.
///
/// The reader tracks a byte cursor (`index`) and a bit offset inside that byte
/// (`bit_pos`, where 0 is the most significant bit). The typed readers
/// `get_uvlc`, `get_le`, `get_leb128`, `get_su` and `get_ns` implement the
/// algorithms of the AV1 bitstream descriptors `uvlc()`, `le(n)`, `leb128()`,
/// `su(n)` and `ns(n)`.
///
/// There is no fallible API: reading past the end of the data panics.
#[derive(Debug, Clone)]
pub struct Buffer<'a> {
    /// The bytes being read.
    buf: &'a [u8],
    /// Index of the byte holding the next bit. Bit reads and seeks never move it
    /// past `buf.len()`.
    index: usize,
    /// Offset of the next bit inside `buf[index]`, always in `0..8`.
    bit_pos: usize,
}

impl<'a> Buffer<'a> {
    /// Creates a reader positioned at the first (most significant) bit of `buf`.
    pub fn from_slice(buf: &'a [u8]) -> Self {
        Self {
            buf,
            index: 0,
            bit_pos: 0,
        }
    }

    /// Skips `cut` bits without reading them.
    ///
    /// Runs in O(1). As with bit-by-bit reading, the byte cursor is clamped at
    /// the end of the buffer, so skipping past the end leaves
    /// `bytes_remaining() == 0` rather than panicking.
    pub fn seek_bits(&mut self, cut: usize) {
        // Split `cut` into whole bytes and leftover bits so nothing can overflow.
        let bits = self.bit_pos + cut % 8;
        let whole_bytes = cut / 8 + bits / 8;
        self.bit_pos = bits % 8;
        if self.index < self.buf.len() {
            self.index = self.index.saturating_add(whole_bytes).min(self.buf.len());
        }
    }

    /// Reads `count` whole bytes and returns them as a sub-slice of the input.
    ///
    /// The returned slice borrows the underlying data (`'a`), not the reader,
    /// so it can be held while reading continues.
    ///
    /// # Panics
    ///
    /// Panics if the reader is not byte aligned or fewer than `count` bytes
    /// remain. The cursor is validated before it is moved.
    pub fn get_bytes(&mut self, count: usize) -> &'a [u8] {
        assert_eq!(self.bit_pos, 0, "get_bytes requires byte alignment");
        let data: &'a [u8] = self.buf;
        let start = self.index;
        let bytes = start
            .checked_add(count)
            .and_then(|end| data.get(start..end))
            .expect("get_bytes: not enough bytes remaining");
        self.index = start + count;
        bytes
    }

    /// Reads a single bit (`f(1)`).
    ///
    /// # Panics
    ///
    /// Panics if no data remains.
    pub fn get_bit(&mut self) -> bool {
        self.next()
    }

    /// Reads `count` bits, most significant first, as an unsigned integer (`f(n)`).
    ///
    /// `count == 0` consumes nothing and returns 0, matching `f(0)`.
    ///
    /// # Panics
    ///
    /// Panics if `count > 32` or the data runs out.
    pub fn get_bits(&mut self, count: usize) -> u32 {
        assert!(count <= 32, "count must be in [0, 32]");
        (0..count).fold(0u32, |acc, _| (acc << 1) | u32::from(self.get_bit()))
    }

    /// Reads an unsigned Exp-Golomb-style variable-length code (`uvlc()`).
    ///
    /// Returns `u32::MAX` when 32 or more leading zero bits are seen, without
    /// reading a suffix.
    ///
    /// # Panics
    ///
    /// Panics if the data runs out before the terminating `1` bit or the suffix.
    pub fn get_uvlc(&mut self) -> u32 {
        let mut leading_zeros = 0usize;
        while !self.get_bit() {
            leading_zeros += 1;
        }
        if leading_zeros >= 32 {
            return u32::MAX;
        }
        // With leading_zeros <= 31 the result is at most 2^32 - 2, so there is
        // no overflow. leading_zeros == 0 reads no suffix and yields 0.
        self.get_bits(leading_zeros) + ((1u32 << leading_zeros) - 1)
    }

    /// Reads `count` bytes as a little-endian unsigned integer (`le(n)`).
    ///
    /// # Panics
    ///
    /// Panics if the reader is not byte aligned, if `count > 4` (the result
    /// would not fit in a `u32`), or if the data runs out.
    pub fn get_le(&mut self, count: usize) -> u32 {
        assert_eq!(self.bit_pos, 0, "get_le requires byte alignment");
        assert!(count <= 4, "get_le can read at most 4 bytes into a u32");
        (0..count).fold(0u32, |acc, i| acc | (self.get_bits(8) << (8 * i)))
    }

    /// Reads an unsigned LEB128 value of at most 8 bytes (`leb128()`).
    ///
    /// # Panics
    ///
    /// Panics if the reader is not byte aligned or the data runs out.
    pub fn get_leb128(&mut self) -> u64 {
        assert_eq!(self.bit_pos, 0, "get_leb128 requires byte alignment");
        let mut value = 0u64;
        for i in 0..8 {
            let byte = u64::from(self.get_bits(8));
            value |= (byte & 0x7f) << (i * 7);
            // A clear high bit marks the last byte of the encoding.
            if byte & 0x80 == 0 {
                break;
            }
        }
        value
    }

    /// Reads a `count`-bit two's-complement signed integer (`su(n)`).
    ///
    /// # Panics
    ///
    /// Panics if `count` is not in `1..=32` or the data runs out.
    pub fn get_su(&mut self, count: usize) -> i32 {
        assert!((1..=32).contains(&count), "count must be in [1, 32]");
        let unused = 32 - count;
        // Move the sign bit to bit 31, reinterpret the bits as i32, then shift
        // back arithmetically. This sign-extends without the `2 * sign_mask`
        // overflow that `count == 32` would otherwise hit.
        ((self.get_bits(count) << unused) as i32) >> unused
    }

    /// Reads a non-symmetric unsigned value in `0..n` (`ns(n)`).
    ///
    /// Returns 0 without reading anything when `n <= 1`.
    ///
    /// # Panics
    ///
    /// Panics if the data runs out.
    pub fn get_ns(&mut self, n: u32) -> u32 {
        if n <= 1 {
            return 0;
        }
        // w = floor(log2(n)) + 1, so 2^(w-1) <= n < 2^w and 2 <= w <= 32.
        let w = (32 - n.leading_zeros()) as usize;
        // m = 2^w - n, computed in u64 because 2^w overflows u32 when w == 32.
        // Since n >= 2^(w-1), m <= 2^31, so the narrowing cast is lossless.
        let m = ((1u64 << w) - u64::from(n)) as u32;
        let v = self.get_bits(w - 1);
        if v < m {
            v
        } else {
            let extra_bit = u32::from(self.get_bit());
            // v < 2^(w-1) <= 2^31, so v << 1 fits, and v >= m keeps the
            // subtraction non-negative.
            (v << 1) - m + extra_bit
        }
    }

    /// Returns `true` when the next bit is the first bit of a byte.
    pub fn is_byte_aligned(&self) -> bool {
        self.bit_pos == 0
    }

    /// Skips forward to the next byte boundary. Does nothing if already aligned.
    pub fn byte_align(&mut self) {
        if self.bit_pos != 0 {
            self.seek_bits(8 - self.bit_pos);
        }
    }

    /// Number of bytes from the cursor's byte to the end of the buffer.
    ///
    /// A partially read byte still counts as remaining.
    pub fn bytes_remaining(&self) -> usize {
        self.buf.len().saturating_sub(self.index)
    }

    /// Number of bytes touched so far. A partially read byte counts as consumed.
    pub fn bytes_consumed(&self) -> usize {
        self.index + usize::from(self.bit_pos > 0)
    }

    /// Steps past one bit, moving to the next byte after the 8th bit.
    ///
    /// The byte cursor is clamped at `buf.len()`.
    fn advance(&mut self) {
        self.bit_pos += 1;
        if self.bit_pos == 8 {
            self.bit_pos = 0;
            if self.index < self.buf.len() {
                self.index += 1;
            }
        }
    }

    /// Returns the bit at the cursor and advances past it.
    ///
    /// # Panics
    ///
    /// Panics if the cursor is at the end of the buffer.
    fn next(&mut self) -> bool {
        let byte = *self
            .buf
            .get(self.index)
            .expect("Buffer: attempted to read past the end of the data");
        let bit = (byte >> (7 - self.bit_pos)) & 1 == 1;
        self.advance();
        bit
    }
}
/// Identity `AsMut` conversion: code that is generic over `AsMut<Buffer>` can
/// be handed a `Buffer` itself, which lends out a mutable reference to itself.
impl<'a> AsMut<Buffer<'a>> for Buffer<'a> {
    fn as_mut(&mut self) -> &mut Buffer<'a> {
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_bit() {
        // 0xB2 == 0b1011_0010, read most significant bit first.
        let data = [0xB2u8];
        let mut buf = Buffer::from_slice(&data);
        let expected = [true, false, true, true, false, false, true, false];
        for &bit in &expected {
            assert_eq!(buf.get_bit(), bit);
        }
    }

    #[test]
    fn test_get_bits() {
        let data = [0xABu8, 0xCD];
        let mut buf = Buffer::from_slice(&data);
        assert_eq!(buf.get_bits(4), 0xA);
        assert_eq!(buf.get_bits(4), 0xB);
        assert_eq!(buf.get_bits(8), 0xCD);
    }

    #[test]
    fn test_get_uvlc() {
        // Bit string 010 011 00100 encodes the values 1, 2, 3.
        let data = [0b0100_1100u8, 0b1000_0000];
        let mut buf = Buffer::from_slice(&data);
        assert_eq!(buf.get_uvlc(), 1);
        assert_eq!(buf.get_uvlc(), 2);
        assert_eq!(buf.get_uvlc(), 3);
    }

    #[test]
    fn test_get_uvlc_saturates() {
        // 32 leading zero bits followed by the terminating one bit.
        let data = [0u8, 0, 0, 0, 0x80];
        let mut buf = Buffer::from_slice(&data);
        assert_eq!(buf.get_uvlc(), u32::MAX);
    }

    #[test]
    fn test_get_le() {
        let data = [0x78u8, 0x56, 0x34, 0x12];
        let mut buf = Buffer::from_slice(&data);
        assert_eq!(buf.get_le(4), 0x1234_5678);
    }

    #[test]
    fn test_get_leb128() {
        let data = [0x05u8];
        let mut buf = Buffer::from_slice(&data);
        assert_eq!(buf.get_leb128(), 5);

        let data2 = [0x80u8, 0x01];
        let mut buf2 = Buffer::from_slice(&data2);
        assert_eq!(buf2.get_leb128(), 128);

        let data3 = [0xE5u8, 0x8E, 0x26];
        let mut buf3 = Buffer::from_slice(&data3);
        assert_eq!(buf3.get_leb128(), 624_485);
    }

    #[test]
    fn test_get_su() {
        let data = [0b1100_0111u8];
        let mut buf = Buffer::from_slice(&data);
        assert_eq!(buf.get_su(4), -4);
        assert_eq!(buf.get_su(4), 7);
    }

    #[test]
    fn test_get_ns() {
        let data = [0b00_01_10_11u8];
        let mut buf = Buffer::from_slice(&data);
        assert_eq!(buf.get_ns(4), 0);
        assert_eq!(buf.get_ns(4), 1);
        assert_eq!(buf.get_ns(4), 2);
        assert_eq!(buf.get_ns(4), 3);
    }

    #[test]
    fn test_get_bytes() {
        let data = [1u8, 2, 3, 4];
        let mut buf = Buffer::from_slice(&data);
        assert_eq!(buf.get_bytes(2), &[1u8, 2][..]);
        assert_eq!(buf.bytes_remaining(), 2);
        assert_eq!(buf.bytes_consumed(), 2);
    }

    #[test]
    fn test_bytes_consumed_counts_partial_byte() {
        let data = [0xFFu8, 0x00, 0x11];
        let mut buf = Buffer::from_slice(&data);
        buf.get_bits(4);
        assert_eq!(buf.bytes_consumed(), 1);
        assert_eq!(buf.bytes_remaining(), 3);
        buf.byte_align();
        assert_eq!(buf.bytes_consumed(), 1);
        assert_eq!(buf.bytes_remaining(), 2);
    }

    #[test]
    fn test_byte_align() {
        let data = [0xFFu8, 0xAA];
        let mut buf = Buffer::from_slice(&data);
        buf.get_bits(3);
        assert!(!buf.is_byte_aligned());
        buf.byte_align();
        assert!(buf.is_byte_aligned());
        assert_eq!(buf.get_bits(8), 0xAA);
    }
}