use super::{Jp2Error, Jp2Result};
/// MSB-first bit writer for JPEG 2000 packet headers.
///
/// Implements the packet-header bit-stuffing rule: whenever an emitted byte
/// is `0xFF`, the following byte carries only 7 data bits and its most
/// significant bit is forced to `0`, so an `0xFF` in the header is never
/// followed by a byte above `0x7F`.
#[derive(Debug, Clone)]
pub struct J2kBitWriter {
    /// Bytes emitted so far.
    out: Vec<u8>,
    /// Byte currently being assembled; bits not yet written stay zero.
    cur: u8,
    /// Number of data bits already placed in `cur`.
    nbits: u8,
    /// Data bits `cur` can hold: 8 normally, 7 right after an `0xFF` byte.
    capacity: u8,
    /// Whether the most recently emitted byte was `0xFF`.
    last_was_ff: bool,
}

impl Default for J2kBitWriter {
    fn default() -> Self {
        Self::new()
    }
}

impl J2kBitWriter {
    /// Creates an empty writer whose first byte holds a full 8 data bits.
    #[must_use]
    pub fn new() -> Self {
        Self {
            out: Vec::new(),
            cur: 0,
            nbits: 0,
            capacity: 8,
            last_was_ff: false,
        }
    }

    /// Appends one bit; only the least significant bit of `bit` is used.
    ///
    /// Bits fill the current byte from its top data bit downwards; the byte
    /// is emitted as soon as its data capacity is reached.
    pub fn write_bit(&mut self, bit: u8) {
        // With capacity 7 the shift never reaches bit 7, which keeps the
        // stuffed MSB at zero.
        let shift = self.capacity - 1 - self.nbits;
        self.cur |= (bit & 1) << shift;
        self.nbits += 1;
        if self.nbits == self.capacity {
            self.flush_byte();
        }
    }

    /// Appends the low `n` bits of `value`, most significant first.
    ///
    /// `n` may exceed 32: bit positions beyond the width of `u32` are
    /// written as `0` (zero extension) instead of overflowing the shift,
    /// which would panic in debug builds and wrap in release builds.
    pub fn write_bits(&mut self, value: u32, n: u8) {
        for i in (0..n).rev() {
            let bit = value
                .checked_shr(u32::from(i))
                .map_or(0, |shifted| (shifted & 1) as u8);
            self.write_bit(bit);
        }
    }

    /// Emits `cur` and starts a fresh byte, reducing its data capacity to
    /// 7 bits when the emitted byte was `0xFF`.
    fn flush_byte(&mut self) {
        let byte = self.cur;
        self.out.push(byte);
        self.last_was_ff = byte == 0xFF;
        self.cur = 0;
        self.nbits = 0;
        self.capacity = if self.last_was_ff { 7 } else { 8 };
    }

    /// Pads any partial byte with zero bits and returns the encoded bytes.
    ///
    /// If the last emitted byte is `0xFF`, the stuffed byte that must follow
    /// it is emitted as well (as `0x00` when it holds no data bits). JPEG 2000
    /// Part 1 (ITU-T T.800, B.10.1) forbids a packet header ending in `0xFF`.
    /// A zero-padded final byte can never itself be `0xFF`, so no further
    /// stuffing is needed after this flush.
    #[must_use]
    pub fn finish(mut self) -> Vec<u8> {
        if self.nbits > 0 || self.last_was_ff {
            self.flush_byte();
        }
        self.out
    }

    /// Number of bit positions used so far. Each emitted byte counts as 8,
    /// including any stuffed MSB. Added to that are the data bits already
    /// placed in the byte under construction; that byte's pending stuffed
    /// MSB is not counted.
    #[must_use]
    pub fn bit_len(&self) -> usize {
        self.out.len() * 8 + usize::from(self.nbits)
    }
}
/// Encodes one code-block's byte length into the packet header.
///
/// Emits, in order:
/// - a single `0` bit (NOTE(review): presumably the one-coding-pass code;
///   confirm against `parse_packet_header`);
/// - `k` one-bits terminated by a `0` bit, where `k` is the number of length
///   bits needed beyond the base of 3;
/// - `value` itself in `3 + k` bits, most significant first.
///
/// # Errors
///
/// Returns `Jp2Error::InternalError` when `value` needs more than 30 bits.
/// The check runs before anything is written, so `writer` is left untouched
/// on error.
fn write_block_length(writer: &mut J2kBitWriter, value: usize) -> Jp2Result<()> {
    /// Minimum number of bits used for the length field.
    const LBLOCK: u32 = 3;
    /// Largest length field this encoder emits.
    const MAX_BITS: u32 = 30;

    // Smallest width >= LBLOCK that holds `value` (0 has no significant bits).
    let significant_bits = usize::BITS - value.leading_zeros();
    let total_bits = significant_bits.max(LBLOCK);
    if total_bits > MAX_BITS {
        return Err(Jp2Error::InternalError(
            "Tier-2 block length exceeds 30 bits".to_string(),
        ));
    }

    writer.write_bit(0);
    for _ in LBLOCK..total_bits {
        writer.write_bit(1);
    }
    writer.write_bit(0);
    // total_bits <= 30 here, so value < 2^30 and both casts are lossless.
    writer.write_bits(value as u32, total_bits as u8);
    Ok(())
}
/// Builds one packet: a bit-stuffed header followed by the bodies of all
/// non-empty code-block streams, concatenated in input order.
///
/// Header layout as written here:
/// - one bit: `0` when every stream is empty (the header then ends), else `1`;
/// - for every block, its inclusion bit, followed by an extra `0` bit when
///   the block is included;
/// - for every included block, its byte length via `write_block_length`.
///
/// # Errors
///
/// Propagates the error from `write_block_length` when a stream's length
/// needs more than 30 bits.
pub fn assemble_packet(block_streams: &[Vec<u8>]) -> Jp2Result<Vec<u8>> {
    let mut header = J2kBitWriter::new();
    let has_body = block_streams.iter().any(|stream| !stream.is_empty());
    header.write_bit(u8::from(has_body));
    if !has_body {
        return Ok(header.finish());
    }

    // Pass 1: inclusion flags for every block.
    for stream in block_streams {
        if stream.is_empty() {
            header.write_bit(0);
        } else {
            header.write_bit(1);
            header.write_bit(0);
        }
    }

    // Pass 2: lengths of the included blocks, in the same order.
    let included = || block_streams.iter().filter(|stream| !stream.is_empty());
    for stream in included() {
        write_block_length(&mut header, stream.len())?;
    }

    // The packet body follows the padded header directly.
    let mut packet = header.finish();
    for stream in included() {
        packet.extend_from_slice(stream);
    }
    Ok(packet)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jpeg2000::bitreader::J2kBitReader;
    use crate::jpeg2000::tier2::parse_packet_header;

    /// Bits written with the writer read back unchanged with the reader.
    #[test]
    fn bitwriter_roundtrip_simple() {
        let mut writer = J2kBitWriter::new();
        writer.write_bit(1);
        writer.write_bit(0);
        writer.write_bits(0b1011, 4);
        writer.write_bits(0xAB, 8);
        let encoded = writer.finish();

        let mut reader = J2kBitReader::new(&encoded);
        let first = reader.read_bit().expect("bit");
        let second = reader.read_bit().expect("bit");
        assert_eq!(first, 1);
        assert_eq!(second, 0);
        let nibble = reader.read_bits(4).expect("bits");
        assert_eq!(nibble, 0b1011);
        let octet = reader.read_bits(8).expect("bits");
        assert_eq!(octet, 0xAB);
    }

    /// The byte after an 0xFF has a clear MSB and still round-trips.
    #[test]
    fn bitwriter_stuffing_after_ff() {
        let mut writer = J2kBitWriter::new();
        writer.write_bits(0xFF, 8);
        writer.write_bits(0b101, 3);
        let encoded = writer.finish();

        assert_eq!(encoded[0], 0xFF);
        let stuffed = encoded[1];
        assert!(stuffed <= 0x7F, "stuffed byte must have top bit 0");

        let mut reader = J2kBitReader::new(&encoded);
        let marker = reader.read_bits(8).expect("bits");
        assert_eq!(marker, 0xFF);
        let payload = reader.read_bits(3).expect("bits");
        assert_eq!(payload, 0b101);
    }

    /// A packet with no data marks every block as excluded.
    #[test]
    fn empty_packet_decodes_all_excluded() {
        let streams: Vec<Vec<u8>> = vec![Vec::new(); 5];
        let packet = assemble_packet(&streams).expect("assemble");

        let mut reader = J2kBitReader::new(&packet);
        let header = parse_packet_header(&mut reader, 5).expect("parse");
        assert!(!header.included_blocks.iter().any(|&included| included));
    }

    /// Inclusion flags and lengths survive a mix of empty and non-empty blocks.
    #[test]
    fn single_included_block_lengths_match() {
        let short = vec![1u8, 2, 3];
        let long = vec![9u8; 200];
        let streams: Vec<Vec<u8>> = vec![short, Vec::new(), long];
        let packet = assemble_packet(&streams).expect("assemble");

        let mut reader = J2kBitReader::new(&packet);
        let header = parse_packet_header(&mut reader, 3).expect("parse");
        assert_eq!(header.included_blocks, vec![true, false, true]);
        assert_eq!(header.data_lengths[0], 3);
        assert_eq!(header.data_lengths[2], 200);
    }

    /// Lengths around the 3-bit base and powers of two round-trip exactly.
    #[test]
    fn lengths_various_sizes() {
        const LENGTHS: [usize; 10] = [1, 7, 8, 15, 16, 100, 255, 256, 1000, 65535];
        for len in LENGTHS {
            let streams = vec![vec![0u8; len]];
            let packet = assemble_packet(&streams).expect("assemble");

            let mut reader = J2kBitReader::new(&packet);
            let header = parse_packet_header(&mut reader, 1).expect("parse");
            assert!(header.included_blocks[0]);
            assert_eq!(header.data_lengths[0], len, "length {len} round-trip");
        }
    }
}