use crate::error::{Error, Result};
use super::data::*;
/// Zigzag scan order of an 8x8 block: entry `k` is the row-major index
/// (`row * 8 + col`) of the coefficient found at zigzag position `k`.
///
/// The writers in this file walk this table front to back and use each entry
/// to index quantization tables and coefficient blocks.
#[rustfmt::skip]
const ZIGZAG: [usize; 64] = [
0, 1, 8, 16, 9, 2, 3, 10,
17, 24, 32, 25, 18, 11, 4, 5,
12, 19, 26, 33, 40, 48, 41, 34,
27, 20, 13, 6, 7, 14, 21, 28,
35, 42, 49, 56, 57, 50, 43, 36,
29, 22, 15, 23, 30, 37, 44, 51,
58, 59, 52, 45, 38, 31, 39, 46,
53, 60, 61, 54, 47, 55, 62, 63,
];
/// Serializes `jpeg` into a JPEG byte stream.
///
/// The output starts with an SOI marker (`FF D8`), continues with one segment
/// per entry of `jpeg.marker_order`, and ends with `jpeg.tail_data` copied
/// verbatim. APP, COM, SOS and inter-marker (`0xFF`) entries each consume the
/// next unused element of their list in `jpeg`; DQT and DHT entries advance a
/// table cursor shared by all segments of that kind. Marker codes without a
/// handler below produce no output.
///
/// # Errors
///
/// Returns `Error::InvalidJbrd` when `marker_order` asks for more APP, COM,
/// SOS or inter-marker items than `jpeg` holds, and forwards any error raised
/// by the individual segment writers.
pub fn write_jpeg(jpeg: &JpegData) -> Result<Vec<u8>> {
    let mut out = Vec::new();
    {
        let mut writer = JpegWriter::new(&mut out);
        // Start of image.
        writer.write_marker(0xD8)?;

        // One cursor per kind of stored payload.
        let mut next_app = 0usize;
        let mut next_com = 0usize;
        let mut next_scan = 0usize;
        let mut next_dqt = 0usize;
        let mut next_dht = 0usize;
        let mut next_gap = 0usize;

        for &marker in &jpeg.marker_order {
            match marker {
                0xC0 => writer.write_sof0(jpeg)?,
                0xC4 => writer.write_dht(jpeg, &mut next_dht)?,
                0xD9 => writer.write_marker(0xD9)?,
                0xDA => {
                    if next_scan >= jpeg.scan_info.len() {
                        return Err(Error::InvalidJbrd("too many SOS markers".into()));
                    }
                    writer.write_sos(jpeg, next_scan)?;
                    next_scan += 1;
                }
                0xDB => writer.write_dqt(jpeg, &mut next_dqt)?,
                0xDD => writer.write_dri(jpeg.restart_interval)?,
                0xE0..=0xEF => {
                    let Some(payload) = jpeg.app_data.get(next_app) else {
                        return Err(Error::InvalidJbrd("too many APP markers".into()));
                    };
                    writer.write_app_marker(marker, payload)?;
                    next_app += 1;
                }
                0xFE => {
                    let Some(comment) = jpeg.com_data.get(next_com) else {
                        return Err(Error::InvalidJbrd("too many COM markers".into()));
                    };
                    writer.write_com_marker(comment)?;
                    next_com += 1;
                }
                0xFF => {
                    let Some(bytes) = jpeg.inter_marker_data.get(next_gap) else {
                        return Err(Error::InvalidJbrd("too many inter-marker data".into()));
                    };
                    writer.write_intermarker_data(bytes);
                    next_gap += 1;
                }
                // Any other marker code is skipped without emitting bytes.
                _ => {}
            }
        }
    }
    out.extend_from_slice(&jpeg.tail_data);
    Ok(out)
}
/// Appends JPEG marker segments and entropy-coded scan data to a borrowed
/// byte vector.
struct JpegWriter<'a> {
/// Output buffer that every `write_*` method appends to.
out: &'a mut Vec<u8>,
}
impl<'a> JpegWriter<'a> {
fn new(out: &'a mut Vec<u8>) -> Self {
Self { out }
}
/// Copies `data` to the output unchanged, with no marker or length prefix.
fn write_intermarker_data(&mut self, data: &[u8]) {
    self.out.extend(data.iter().copied());
}
/// Emits the two-byte marker `FF <marker>` with no length or payload.
///
/// Always returns `Ok(())`.
fn write_marker(&mut self, marker: u8) -> Result<()> {
    self.out.extend_from_slice(&[0xFF, marker]);
    Ok(())
}
fn write_app_marker(&mut self, marker: u8, data: &[u8]) -> Result<()> {
self.out.push(0xFF);
self.out.push(marker);
let len = (data.len() + 2) as u16;
self.out.extend_from_slice(&len.to_be_bytes());
self.out.extend_from_slice(data);
Ok(())
}
fn write_com_marker(&mut self, data: &[u8]) -> Result<()> {
self.out.push(0xFF);
self.out.push(0xFE);
let len = (data.len() + 2) as u16;
self.out.extend_from_slice(&len.to_be_bytes());
self.out.extend_from_slice(data);
Ok(())
}
fn write_dri(&mut self, restart_interval: u32) -> Result<()> {
self.out.push(0xFF);
self.out.push(0xDD);
self.out.extend_from_slice(&4u16.to_be_bytes()); self.out
.extend_from_slice(&(restart_interval as u16).to_be_bytes());
Ok(())
}
fn write_dqt(&mut self, jpeg: &JpegData, idx: &mut usize) -> Result<()> {
self.out.push(0xFF);
self.out.push(0xDB);
let start = *idx;
let mut total_payload = 0usize;
loop {
if *idx >= jpeg.quant.len() {
return Err(Error::InvalidJbrd("too many DQT tables".into()));
}
let qt = &jpeg.quant[*idx];
let precision_bytes = if qt.precision == 0 { 1 } else { 2 };
total_payload += 1 + 64 * precision_bytes; let is_last = qt.is_last;
*idx += 1;
if is_last {
break;
}
}
let length = (total_payload + 2) as u16;
self.out.extend_from_slice(&length.to_be_bytes());
for i in start..*idx {
let qt = &jpeg.quant[i];
let pq_tq = ((qt.precision as u8) << 4) | (qt.index as u8);
self.out.push(pq_tq);
if qt.precision == 0 {
for &zi in &ZIGZAG {
self.out.push(qt.values[zi] as u8);
}
} else {
for &zi in &ZIGZAG {
self.out
.extend_from_slice(&(qt.values[zi] as u16).to_be_bytes());
}
}
}
Ok(())
}
fn write_dht(&mut self, jpeg: &JpegData, idx: &mut usize) -> Result<()> {
self.out.push(0xFF);
self.out.push(0xC4);
let start = *idx;
let mut total_payload = 0usize;
loop {
if *idx >= jpeg.huffman_code.len() {
return Err(Error::InvalidJbrd("too many DHT tables".into()));
}
let hc = &jpeg.huffman_code[*idx];
let num_values: u32 = hc.counts.iter().sum();
total_payload += 1 + 16 + num_values as usize;
let is_last = hc.is_last;
*idx += 1;
if is_last {
break;
}
}
let length = (total_payload + 2) as u16;
self.out.extend_from_slice(&length.to_be_bytes());
for i in start..*idx {
let hc = &jpeg.huffman_code[i];
let tc_th = if hc.is_ac { 0x10 } else { 0x00 } | (hc.id as u8);
self.out.push(tc_th);
for &count in &hc.counts {
self.out.push(count as u8);
}
for &val in &hc.values {
self.out.push(val);
}
}
Ok(())
}
fn write_sof0(&mut self, jpeg: &JpegData) -> Result<()> {
self.out.push(0xFF);
self.out.push(0xC0);
let nc = jpeg.components.len();
let length = (8 + 3 * nc) as u16;
self.out.extend_from_slice(&length.to_be_bytes());
self.out.push(8); self.out
.extend_from_slice(&(jpeg.height as u16).to_be_bytes());
self.out
.extend_from_slice(&(jpeg.width as u16).to_be_bytes());
self.out.push(nc as u8);
for comp in &jpeg.components {
self.out.push(comp.id as u8);
let hv = ((comp.h_samp_factor as u8) << 4) | (comp.v_samp_factor as u8);
self.out.push(hv);
self.out.push(comp.quant_idx as u8);
}
Ok(())
}
fn write_sos(&mut self, jpeg: &JpegData, scan_idx: usize) -> Result<()> {
let scan = &jpeg.scan_info[scan_idx];
self.out.push(0xFF);
self.out.push(0xDA);
let length = (6 + 2 * scan.num_components) as u16;
self.out.extend_from_slice(&length.to_be_bytes());
self.out.push(scan.num_components as u8);
for i in 0..scan.num_components as usize {
let comp_idx = scan.component_indices[i] as usize;
self.out.push(jpeg.components[comp_idx].id as u8);
let td_ta = ((scan.dc_tbl_idx[i] as u8) << 4) | (scan.ac_tbl_idx[i] as u8);
self.out.push(td_ta);
}
self.out.push(scan.ss as u8);
self.out.push(scan.se as u8);
let ah_al = ((scan.ah as u8) << 4) | (scan.al as u8);
self.out.push(ah_al);
self.write_scan_data(jpeg, scan_idx)?;
Ok(())
}
fn write_scan_data(&mut self, jpeg: &JpegData, scan_idx: usize) -> Result<()> {
let scan = &jpeg.scan_info[scan_idx];
let mut dc_tables: [Option<HuffmanEncodeTable>; 4] = [None, None, None, None];
let mut ac_tables: [Option<HuffmanEncodeTable>; 4] = [None, None, None, None];
for hc in &jpeg.huffman_code {
let table = HuffmanEncodeTable::from_counts_values(&hc.counts, &hc.values);
if hc.is_ac {
ac_tables[hc.id as usize] = Some(table);
} else {
dc_tables[hc.id as usize] = Some(table);
}
}
let mut bw = BitWriter::new();
let mut padding_bit_idx = 0usize;
let mut reset_point_idx = 0usize;
let mut extra_zero_idx = 0usize;
let mut dc_pred = vec![0i32; jpeg.components.len()];
let is_interleaved = scan.num_components > 1;
let (mcu_rows, mcu_cols) = if is_interleaved {
let max_h: u32 = jpeg
.components
.iter()
.map(|c| c.h_samp_factor)
.max()
.unwrap_or(1);
let max_v: u32 = jpeg
.components
.iter()
.map(|c| c.v_samp_factor)
.max()
.unwrap_or(1);
let mcu_cols = jpeg.width.div_ceil(max_h * 8);
let mcu_rows = jpeg.height.div_ceil(max_v * 8);
(mcu_rows, mcu_cols)
} else {
let comp_idx = scan.component_indices[0] as usize;
let comp = &jpeg.components[comp_idx];
(comp.height_in_blocks, comp.width_in_blocks)
};
let mut block_count: u32 = 0;
for mcu_row in 0..mcu_rows {
for mcu_col in 0..mcu_cols {
if reset_point_idx < scan.reset_points.len()
&& block_count == scan.reset_points[reset_point_idx]
{
bw.pad_to_byte(&jpeg.padding_bits, &mut padding_bit_idx);
self.out.extend_from_slice(&bw.finish());
bw = BitWriter::new();
let rst_marker = 0xD0 + ((reset_point_idx % 8) as u8);
self.out.push(0xFF);
self.out.push(rst_marker);
dc_pred.fill(0);
reset_point_idx += 1;
}
for sci in 0..scan.num_components as usize {
let comp_idx = scan.component_indices[sci] as usize;
let comp = &jpeg.components[comp_idx];
let dc_table = dc_tables[scan.dc_tbl_idx[sci] as usize]
.as_ref()
.ok_or_else(|| Error::InvalidJbrd("missing DC table".into()))?;
let ac_table = ac_tables[scan.ac_tbl_idx[sci] as usize]
.as_ref()
.ok_or_else(|| Error::InvalidJbrd("missing AC table".into()))?;
let (h_blocks, v_blocks) = if is_interleaved {
(comp.h_samp_factor, comp.v_samp_factor)
} else {
(1, 1)
};
for v in 0..v_blocks {
for h in 0..h_blocks {
let (by, bx) = if is_interleaved {
(
mcu_row * comp.v_samp_factor + v,
mcu_col * comp.h_samp_factor + h,
)
} else {
(mcu_row, mcu_col)
};
if by >= comp.height_in_blocks || bx >= comp.width_in_blocks {
encode_dc(&mut bw, 0, &mut dc_pred[comp_idx], dc_table);
encode_ac_eob(&mut bw, ac_table);
} else {
let block_offset = (by * comp.width_in_blocks + bx) as usize * 64;
let coeffs = &comp.coeffs[block_offset..block_offset + 64];
while extra_zero_idx < scan.extra_zero_runs.len()
&& scan.extra_zero_runs[extra_zero_idx].0 == block_count
{
let num_runs = scan.extra_zero_runs[extra_zero_idx].1;
for _ in 0..num_runs {
bw.write_huffman(ac_table, 0xF0);
}
extra_zero_idx += 1;
}
encode_dc(
&mut bw,
coeffs[0] as i32,
&mut dc_pred[comp_idx],
dc_table,
);
encode_ac(&mut bw, coeffs, ac_table);
}
block_count += 1;
}
}
}
}
}
bw.pad_to_byte(&jpeg.padding_bits, &mut padding_bit_idx);
self.out.extend_from_slice(&bw.finish());
Ok(())
}
}
/// Encodes a DC coefficient as the difference from the previous one.
///
/// Stores `dc` in `*dc_pred`, writes the Huffman code for the size category
/// of the difference, and then writes the category's extra bits when there
/// are any.
fn encode_dc(bw: &mut BitWriter, dc: i32, dc_pred: &mut i32, table: &HuffmanEncodeTable) {
    let delta = dc - *dc_pred;
    *dc_pred = dc;
    match categorize(delta) {
        // A zero difference carries no extra bits.
        (size, _, 0) => bw.write_huffman(table, size as u8),
        (size, bits, bit_count) => {
            bw.write_huffman(table, size as u8);
            bw.write_bits(bits, bit_count);
        }
    }
}
/// Encodes the 63 AC coefficients of `block`, visiting them in `ZIGZAG` order.
///
/// Each non-zero coefficient is written as a Huffman symbol packing the
/// preceding zero run (high nibble) and the size category (low nibble),
/// followed by the category's extra bits. Runs longer than 15 zeros are split
/// with `0xF0` symbols. An end-of-block symbol (`0x00`) is written unless the
/// last zigzag position holds a non-zero coefficient.
fn encode_ac(bw: &mut BitWriter, block: &[i16], table: &HuffmanEncodeTable) {
    const EOB: u8 = 0x00;
    const ZRL: u8 = 0xF0;

    // Highest zigzag position (1..=63) holding a non-zero coefficient.
    let last = match (1..64).rev().find(|&zi| block[ZIGZAG[zi]] != 0) {
        Some(zi) => zi,
        None => {
            bw.write_huffman(table, EOB);
            return;
        }
    };

    let mut pending_zeros = 0u32;
    for &pos in &ZIGZAG[1..=last] {
        let coeff = block[pos];
        if coeff == 0 {
            pending_zeros += 1;
            continue;
        }
        while pending_zeros > 15 {
            bw.write_huffman(table, ZRL);
            pending_zeros -= 16;
        }
        let (size, bits, bit_count) = categorize(i32::from(coeff));
        bw.write_huffman(table, ((pending_zeros as u8) << 4) | (size as u8));
        if bit_count > 0 {
            bw.write_bits(bits, bit_count);
        }
        pending_zeros = 0;
    }

    if last < 63 {
        bw.write_huffman(table, EOB);
    }
}
/// Writes a lone end-of-block symbol (`0x00`) using `table`.
fn encode_ac_eob(bw: &mut BitWriter, table: &HuffmanEncodeTable) {
    const EOB: u8 = 0x00;
    bw.write_huffman(table, EOB);
}
/// Splits `value` into its JPEG size category and extra bits.
///
/// Returns `(category, extra_bits, extra_len)`, where `category` is the
/// number of bits needed for `|value|` and `extra_len` always equals
/// `category`. For a positive value the extra bits are the value itself; for
/// a negative value they are the low `category` bits of `value - 1`, which is
/// the same as `value + 2^category - 1`. Zero yields `(0, 0, 0)`.
///
/// The computation is overflow-free for every `i32`. Previously values with
/// a magnitude of 2^30 or more overflowed.
fn categorize(value: i32) -> (u32, u32, u32) {
    if value == 0 {
        return (0, 0, 0);
    }
    // 1..=32, because `value` is non-zero.
    let category = 32 - value.unsigned_abs().leading_zeros();
    let extra_bits = if value > 0 {
        value as u32
    } else {
        // Low `category` bits of the two's-complement form of `value - 1`.
        let mask = u32::MAX >> (32 - category);
        (value.wrapping_sub(1) as u32) & mask
    };
    (category, extra_bits, category)
}
/// Symbol-indexed Huffman lookup table used for encoding.
struct HuffmanEncodeTable {
    /// Code bits for each symbol, right-aligned.
    codes: [u32; 256],
    /// Code length in bits for each symbol; 0 means the symbol has no code.
    lengths: [u8; 256],
}

impl HuffmanEncodeTable {
    /// Builds the table from a DHT-style description.
    ///
    /// `counts[i]` is the number of codes of length `i + 1`, and `values`
    /// lists the symbols in order of increasing code length. Codes are
    /// assigned consecutively within a length, and the running code is
    /// shifted left by one bit when moving to the next length. If a symbol
    /// appears more than once, the last assignment wins.
    ///
    /// Assignment stops as soon as `values` is exhausted. The resulting table
    /// is the same as if the remaining counts had been walked, but an
    /// oversized count can no longer keep the loop spinning with nothing left
    /// to assign.
    fn from_counts_values(counts: &[u32; 16], values: &[u8]) -> Self {
        let mut codes = [0u32; 256];
        let mut lengths = [0u8; 256];
        let mut code: u32 = 0;
        let mut remaining = values.iter();

        'lengths: for (bits_minus_1, &count) in counts.iter().enumerate() {
            let bits = bits_minus_1 as u8 + 1;
            for _ in 0..count {
                let symbol = match remaining.next() {
                    Some(&symbol) => usize::from(symbol),
                    // No symbols left: nothing further can change the table.
                    None => break 'lengths,
                };
                codes[symbol] = code;
                lengths[symbol] = bits;
                code += 1;
            }
            code <<= 1;
        }

        Self { codes, lengths }
    }
}
/// MSB-first bit accumulator that produces JPEG entropy-coded bytes.
struct BitWriter {
    /// Completed output bytes, including any stuffed `0x00` bytes.
    buffer: Vec<u8>,
    /// Recently written bits; only the low `bits_in_buffer` bits are pending.
    bit_buffer: u32,
    /// Number of pending bits not yet flushed to `buffer`.
    bits_in_buffer: u32,
}

impl BitWriter {
    /// Creates an empty writer with no pending bits.
    fn new() -> Self {
        BitWriter {
            buffer: Vec::new(),
            bit_buffer: 0,
            bits_in_buffer: 0,
        }
    }

    /// Writes the code that `table` assigns to `symbol`; symbols whose code
    /// length is 0 produce no output.
    fn write_huffman(&mut self, table: &HuffmanEncodeTable, symbol: u8) {
        let slot = usize::from(symbol);
        match table.lengths[slot] {
            0 => {}
            len => self.write_bits(table.codes[slot], u32::from(len)),
        }
    }

    /// Appends the low `num_bits` bits of `value`, most significant bit
    /// first, flushing every completed byte to the buffer. A flushed `0xFF`
    /// byte is followed by a stuffed `0x00`.
    fn write_bits(&mut self, value: u32, num_bits: u32) {
        let shifted = self.bit_buffer << num_bits;
        let mask: u32 = (1 << num_bits) - 1;
        self.bit_buffer = shifted | (value & mask);
        self.bits_in_buffer += num_bits;

        while self.bits_in_buffer >= 8 {
            self.bits_in_buffer -= 8;
            // The cast keeps only the eight bits being flushed.
            let byte = (self.bit_buffer >> self.bits_in_buffer) as u8;
            self.buffer.push(byte);
            if byte == 0xFF {
                self.buffer.push(0x00);
            }
        }
    }

    /// Fills the current partial byte, one bit at a time. Bits are taken from
    /// `padding_bits` starting at `*padding_idx` (which is advanced for each
    /// bit used); once that list is exhausted, 1 bits are used.
    fn pad_to_byte(&mut self, padding_bits: &[u8], padding_idx: &mut usize) {
        // Number of bits still missing from the current byte (0 if aligned).
        let missing = (8 - self.bits_in_buffer % 8) % 8;
        for _ in 0..missing {
            let bit = match padding_bits.get(*padding_idx) {
                Some(&stored) => {
                    *padding_idx += 1;
                    stored
                }
                None => 1,
            };
            self.write_bits(u32::from(bit), 1);
        }
    }

    /// Consumes the writer and returns the bytes produced so far. Debug
    /// builds assert that no partial byte is pending.
    fn finish(self) -> Vec<u8> {
        debug_assert!(self.bits_in_buffer == 0);
        let BitWriter { buffer, .. } = self;
        buffer
    }
}