use super::error::{JpegError, Result};
/// Reads MSB-first bit fields from a JPEG entropy-coded segment.
///
/// Stuffed bytes (`0xFF 0x00`) are decoded as a literal `0xFF` data byte.
/// When a marker (`0xFF`, optional `0xFF` fill bytes, then a non-zero code) is
/// reached while refilling, its code is recorded (see
/// [`BitReader::marker_found`]) and the input position is left just past the
/// marker. Every further refill supplies `0xFF` padding instead of consuming
/// input, so no byte after the marker is ever swallowed into the bit buffer.
/// A recorded restart marker is cleared again by
/// [`BitReader::check_restart_marker`].
pub struct BitReader<'a> {
    /// Input buffer.
    data: &'a [u8],
    /// Index in `data` of the next byte to load into `buf`.
    pos: usize,
    /// Bit accumulator; only the low `bits_left` bits are unread data.
    buf: u32,
    /// Number of unread bits in the low end of `buf` (refills stop at 16, so at most 23).
    bits_left: u8,
    /// Code of a marker reached while refilling that has not been consumed yet.
    marker_found: Option<u8>,
}

impl<'a> BitReader<'a> {
    /// Creates a reader over `data`, starting at byte offset `pos`.
    pub fn new(data: &'a [u8], pos: usize) -> Self {
        Self {
            data,
            pos,
            buf: 0,
            bits_left: 0,
            marker_found: None,
        }
    }

    /// Reads and consumes the next `count` bits (1..=16), most significant first.
    ///
    /// # Errors
    /// Returns [`JpegError::UnexpectedEof`] if the input ends before enough bits
    /// are available and no marker has been reached.
    pub fn read_bits(&mut self, count: u8) -> Result<u16> {
        let val = self.peek_bits(count)?;
        self.bits_left -= count;
        Ok(val)
    }

    /// Returns the next `count` bits (1..=16) without consuming them.
    ///
    /// Bytes loaded to satisfy the peek stay buffered for later reads.
    ///
    /// # Errors
    /// Returns [`JpegError::UnexpectedEof`] if the input ends before enough bits
    /// are available and no marker has been reached.
    pub fn peek_bits(&mut self, count: u8) -> Result<u16> {
        debug_assert!((1..=16).contains(&count));
        while self.bits_left < count {
            self.fill_byte()?;
        }
        let val = (self.buf >> (self.bits_left - count)) & ((1u32 << count) - 1);
        Ok(val as u16)
    }

    /// Discards `count` bits that are already buffered (typically after [`BitReader::peek_bits`]).
    ///
    /// `count` must not exceed the number of buffered bits; this is only checked
    /// by a `debug_assert!`.
    pub fn skip_bits(&mut self, count: u8) {
        debug_assert!(count <= self.bits_left);
        self.bits_left -= count;
    }

    /// Drops every buffered bit, including whole bytes prefetched by a peek.
    ///
    /// The input position is not changed.
    pub fn byte_align(&mut self) {
        self.bits_left = 0;
        self.buf = 0;
    }

    /// Byte offset of the next input byte not yet loaded into the bit buffer.
    pub fn position(&self) -> usize {
        self.pos
    }

    /// Marker code reached while refilling the bit buffer, if any.
    pub fn marker_found(&self) -> Option<u8> {
        self.marker_found
    }

    /// Byte-aligns the reader and consumes a restart marker (RST0..RST7) if one is next.
    ///
    /// Returns `Some(n)` for marker RSTn, or `None` when no restart marker is
    /// found. A restart marker already reached while refilling takes priority
    /// and is cleared. Otherwise the input is scanned at the current position,
    /// skipping `0xFF` fill bytes. The current implementation never returns `Err`.
    pub fn check_restart_marker(&mut self) -> Result<Option<u8>> {
        self.byte_align();
        if let Some(m) = self.marker_found.filter(|&m| Self::is_rst(m)) {
            self.marker_found = None;
            return Ok(Some(m & 0x07));
        }
        while self.pos + 1 < self.data.len() && self.data[self.pos] == 0xFF {
            let next = self.data[self.pos + 1];
            if next == 0xFF {
                // Fill byte: the marker code, if any, comes later.
                self.pos += 1;
                continue;
            }
            if Self::is_rst(next) {
                self.pos += 2;
                return Ok(Some(next & 0x07));
            }
            break;
        }
        Ok(None)
    }

    /// True for marker codes 0xD0..=0xD7 (RST0..RST7).
    fn is_rst(code: u8) -> bool {
        (code & 0xF8) == 0xD0
    }

    /// Returns the input byte at `pos` and advances past it.
    fn next_byte(&mut self) -> Result<u8> {
        let byte = *self.data.get(self.pos).ok_or(JpegError::UnexpectedEof)?;
        self.pos += 1;
        Ok(byte)
    }

    /// Appends eight bits to the accumulator.
    fn fill_byte(&mut self) -> Result<()> {
        let byte = if self.marker_found.is_some() {
            // The entropy-coded segment ended at the marker. Pad with 1-bits and
            // leave `pos` just past the marker, so the bytes that follow it are
            // neither decoded nor lost when the buffer is later byte-aligned.
            0xFF
        } else {
            let byte = self.next_byte()?;
            if byte == 0xFF {
                // Any number of 0xFF fill bytes may precede a marker code.
                while self.data.get(self.pos) == Some(&0xFF) {
                    self.pos += 1;
                }
                let code = self.next_byte()?;
                if code != 0x00 {
                    // Not a stuffed zero, so this is a marker. The 0xFF prefix is
                    // still delivered as a padding byte.
                    self.marker_found = Some(code);
                }
            }
            byte
        };
        self.buf = (self.buf << 8) | u32::from(byte);
        self.bits_left += 8;
        Ok(())
    }
}
/// Packs MSB-first bit fields into a JPEG entropy-coded byte stream.
///
/// Every emitted `0xFF` byte is followed by a stuffed `0x00`, so the output
/// never contains an accidental marker.
pub struct BitWriter {
    /// Finished, byte-stuffed output.
    output: Vec<u8>,
    /// Partially filled byte; its low `bits_used` bits are pending.
    buf: u8,
    /// Number of pending bits in `buf` (below 8 between calls).
    bits_used: u8,
}

impl Default for BitWriter {
    fn default() -> Self {
        Self::new()
    }
}

impl BitWriter {
    /// Creates a writer with no output and no pending bits.
    pub fn new() -> Self {
        BitWriter {
            bits_used: 0,
            buf: 0,
            output: Vec::new(),
        }
    }

    /// Appends the low `count` bits (1..=16) of `value`, most significant bit first.
    ///
    /// Bits of `value` above `count` are ignored.
    pub fn write_bits(&mut self, value: u16, count: u8) {
        debug_assert!((1..=16).contains(&count));
        (0..count)
            .rev()
            .map(|shift| (value >> shift) & 1 != 0)
            .for_each(|bit| self.push_bit(bit));
    }

    /// Pads the last partial byte with 1-bits and returns the finished stream.
    pub fn flush(mut self) -> Vec<u8> {
        let pending = self.bits_used;
        if pending != 0 {
            let pad = 8 - pending;
            // `u8::MAX >> pending` is exactly `pad` low 1-bits.
            let last = (self.buf << pad) | (u8::MAX >> pending);
            self.emit_byte(last);
        }
        self.output
    }

    /// Shifts one bit into the pending byte and emits it once eight bits are collected.
    fn push_bit(&mut self, bit: bool) {
        self.buf = (self.buf << 1) | u8::from(bit);
        self.bits_used += 1;
        if self.bits_used == 8 {
            let full = self.buf;
            self.buf = 0;
            self.bits_used = 0;
            self.emit_byte(full);
        }
    }

    /// Appends `byte` to the output, stuffing a `0x00` after any `0xFF`.
    fn emit_byte(&mut self, byte: u8) {
        match byte {
            0xFF => self.output.extend_from_slice(&[0xFF, 0x00]),
            other => self.output.push(other),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn read_basic_bits() {
        let data = [0xA5];
        let mut r = BitReader::new(&data, 0);
        assert_eq!(r.read_bits(4).unwrap(), 0b1010);
        assert_eq!(r.read_bits(4).unwrap(), 0b0101);
    }

    #[test]
    fn read_cross_byte() {
        let data = [0xFF, 0x00, 0x80];
        let mut r = BitReader::new(&data, 0);
        assert_eq!(r.read_bits(12).unwrap(), 0xFF8);
    }

    #[test]
    fn read_sixteen_bits() {
        let data = [0xAB, 0xCD];
        let mut r = BitReader::new(&data, 0);
        assert_eq!(r.read_bits(16).unwrap(), 0xABCD);
        assert_eq!(r.position(), 2);
    }

    #[test]
    fn byte_stuffing_decode() {
        let data = [0xFF, 0x00];
        let mut r = BitReader::new(&data, 0);
        assert_eq!(r.read_bits(8).unwrap(), 0xFF);
    }

    #[test]
    fn marker_detection() {
        let data = [0xAB, 0xFF, 0xD9];
        let mut r = BitReader::new(&data, 0);
        assert_eq!(r.read_bits(8).unwrap(), 0xAB);
        let _ = r.read_bits(8);
        assert_eq!(r.marker_found(), Some(0xD9));
    }

    #[test]
    fn read_past_end_is_error() {
        let mut r = BitReader::new(&[], 0);
        assert!(r.read_bits(1).is_err());
        // A trailing 0xFF can be neither stuffing nor a complete marker.
        let data = [0xFF];
        let mut r = BitReader::new(&data, 0);
        assert!(r.read_bits(8).is_err());
    }

    #[test]
    fn peek_then_skip() {
        let data = [0xA5];
        let mut r = BitReader::new(&data, 0);
        assert_eq!(r.peek_bits(4).unwrap(), 0b1010);
        r.skip_bits(4);
        assert_eq!(r.read_bits(4).unwrap(), 0b0101);
    }

    #[test]
    fn restart_marker_in_stream() {
        let data = [0xFF, 0xD3, 0x12];
        let mut r = BitReader::new(&data, 0);
        assert_eq!(r.check_restart_marker().unwrap(), Some(3));
        assert_eq!(r.position(), 2);
        assert_eq!(r.read_bits(8).unwrap(), 0x12);
    }

    #[test]
    fn restart_marker_after_fill_bytes() {
        let data = [0xFF, 0xFF, 0xD5];
        let mut r = BitReader::new(&data, 0);
        assert_eq!(r.check_restart_marker().unwrap(), Some(5));
        assert_eq!(r.position(), 3);
    }

    #[test]
    fn restart_marker_already_buffered() {
        let data = [0xFF, 0xD1];
        let mut r = BitReader::new(&data, 0);
        assert_eq!(r.read_bits(8).unwrap(), 0xFF);
        assert_eq!(r.marker_found(), Some(0xD1));
        assert_eq!(r.check_restart_marker().unwrap(), Some(1));
        assert_eq!(r.marker_found(), None);
    }

    #[test]
    fn no_restart_marker() {
        let data = [0x12, 0x34];
        let mut r = BitReader::new(&data, 0);
        assert_eq!(r.check_restart_marker().unwrap(), None);
        assert_eq!(r.position(), 0);
    }

    #[test]
    fn write_basic() {
        let mut w = BitWriter::new();
        w.write_bits(0b1010, 4);
        w.write_bits(0b0101, 4);
        let out = w.flush();
        assert_eq!(out, vec![0xA5]);
    }

    #[test]
    fn write_byte_stuffing() {
        let mut w = BitWriter::new();
        w.write_bits(0xFF, 8);
        let out = w.flush();
        assert_eq!(out, vec![0xFF, 0x00]);
    }

    #[test]
    fn write_padding() {
        let mut w = BitWriter::new();
        w.write_bits(0b110, 3);
        let out = w.flush();
        assert_eq!(out, vec![0xDF]);
    }

    #[test]
    fn write_cross_byte() {
        let mut w = BitWriter::new();
        w.write_bits(0b1111_1111_1000, 12);
        let out = w.flush();
        assert_eq!(out, vec![0xFF, 0x00, 0x8F]);
    }

    #[test]
    fn write_sixteen_bits() {
        let mut w = BitWriter::new();
        w.write_bits(0xABCD, 16);
        assert_eq!(w.flush(), vec![0xAB, 0xCD]);
    }

    #[test]
    fn write_empty_flush() {
        assert!(BitWriter::new().flush().is_empty());
    }

    #[test]
    fn write_then_read_round_trip() {
        let fields: [(u16, u8); 4] = [(0xFF, 8), (0b101, 3), (0x1234, 16), (0, 1)];
        let mut w = BitWriter::new();
        for &(value, count) in &fields {
            w.write_bits(value, count);
        }
        let out = w.flush();
        let mut r = BitReader::new(&out, 0);
        for &(value, count) in &fields {
            assert_eq!(r.read_bits(count).unwrap(), value);
        }
        assert_eq!(r.marker_found(), None);
    }
}