use super::bitio::{BitReader, BitWriter};
use super::dct::DctGrid;
use super::error::{JpegError, Result};
use super::frame::FrameInfo;
use super::huffman::{
encode_value, extend_sign, HuffmanDecodeTable, HuffmanEncodeTable,
};
use super::marker::SosParams;
use super::tables::HuffmanSpec;
use super::zigzag::{NATURAL_TO_ZIGZAG, ZIGZAG_TO_NATURAL};
/// One component entry of an SOS (start-of-scan) header, resolved to indices.
///
/// `comp_idx` indexes `FrameInfo::components`; `dc_table` / `ac_table` index the
/// four Huffman table slots (`dc_specs` / `ac_specs`) used by the scan codecs.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ScanComponent {
    /// Index of the frame component this scan entry refers to.
    pub comp_idx: usize,
    /// Huffman table slot (0..=3) used for DC coefficients.
    pub dc_table: usize,
    /// Huffman table slot (0..=3) used for AC coefficients.
    pub ac_table: usize,
}
pub fn decode_scan(
data: &[u8],
scan_start: usize,
frame: &FrameInfo,
scan_components: &[ScanComponent],
dc_specs: &[Option<HuffmanSpec>; 4],
ac_specs: &[Option<HuffmanSpec>; 4],
restart_interval: u16,
) -> Result<(Vec<DctGrid>, usize)> {
let mut dc_tables: [Option<HuffmanDecodeTable>; 4] = [None, None, None, None];
let mut ac_tables: [Option<HuffmanDecodeTable>; 4] = [None, None, None, None];
for sc in scan_components {
if dc_tables[sc.dc_table].is_none() {
let spec = dc_specs[sc.dc_table]
.as_ref()
.ok_or(JpegError::InvalidHuffmanTableId(sc.dc_table as u8))?;
dc_tables[sc.dc_table] = Some(HuffmanDecodeTable::build(&spec.bits, &spec.huffval)?);
}
if ac_tables[sc.ac_table].is_none() {
let spec = ac_specs[sc.ac_table]
.as_ref()
.ok_or(JpegError::InvalidHuffmanTableId(sc.ac_table as u8))?;
ac_tables[sc.ac_table] = Some(HuffmanDecodeTable::build(&spec.bits, &spec.huffval)?);
}
}
let mut grids: Vec<DctGrid> = Vec::with_capacity(scan_components.len());
for sc in scan_components {
let bw = frame.blocks_wide(sc.comp_idx);
let bt = frame.blocks_tall(sc.comp_idx);
grids.push(DctGrid::new(bw, bt));
}
let mut dc_pred = vec![0i32; scan_components.len()];
let mut reader = BitReader::new(data, scan_start);
let mut mcu_count = 0usize;
for mcu_row in 0..frame.mcus_tall as usize {
for mcu_col in 0..frame.mcus_wide as usize {
if restart_interval > 0 && mcu_count > 0 && mcu_count.is_multiple_of(restart_interval as usize) {
reader.byte_align();
let _rst = reader.check_restart_marker()?;
dc_pred.fill(0);
}
for (sci, sc) in scan_components.iter().enumerate() {
let comp = &frame.components[sc.comp_idx];
let dc_tab = dc_tables[sc.dc_table].as_ref().unwrap();
let ac_tab = ac_tables[sc.ac_table].as_ref().unwrap();
for v in 0..comp.v_sampling as usize {
for h in 0..comp.h_sampling as usize {
let block_row = mcu_row * (comp.v_sampling as usize) + v;
let block_col = mcu_col * (comp.h_sampling as usize) + h;
let blocks_tall = grids[sci].blocks_tall();
let blocks_wide = grids[sci].blocks_wide();
if block_row >= blocks_tall || block_col >= blocks_wide {
let dc_size = dc_tab.decode(&mut reader)?;
if dc_size > 0 {
let dc_bits = reader.read_bits(dc_size)?;
let dc_diff = extend_sign(dc_bits, dc_size);
dc_pred[sci] += dc_diff as i32;
}
let mut k = 1;
while k < 64 {
let rs = ac_tab.decode(&mut reader)?;
let run = (rs >> 4) as usize;
let size = rs & 0x0F;
if size == 0 {
if run == 0 || run != 15 { break; }
k += 16;
continue;
}
k += run;
if k >= 64 { return Err(JpegError::HuffmanDecode); }
let _ac_bits = reader.read_bits(size)?;
k += 1;
}
continue;
}
let mut zz = [0i16; 64];
let dc_size = dc_tab.decode(&mut reader)?;
if dc_size > 0 {
let dc_bits = reader.read_bits(dc_size)?;
let dc_diff = extend_sign(dc_bits, dc_size);
dc_pred[sci] += dc_diff as i32;
}
zz[0] = dc_pred[sci].clamp(i16::MIN as i32, i16::MAX as i32) as i16;
let mut k = 1;
while k < 64 {
let rs = ac_tab.decode(&mut reader)?;
let run = (rs >> 4) as usize;
let size = rs & 0x0F;
if size == 0 {
if run == 0 {
break;
} else if run == 15 {
k += 16;
continue;
} else {
break;
}
}
k += run;
if k >= 64 {
return Err(JpegError::HuffmanDecode);
}
let ac_bits = reader.read_bits(size)?;
zz[k] = extend_sign(ac_bits, size);
k += 1;
}
let block = grids[sci].block_mut(block_row, block_col);
for zi in 0..64 {
block[ZIGZAG_TO_NATURAL[zi]] = zz[zi];
}
}
}
}
mcu_count += 1;
}
}
let end_pos = reader.position();
Ok((grids, end_pos))
}
/// Huffman-encodes one baseline scan without progress reporting.
///
/// This is a convenience wrapper around `encode_scan_with_progress` that passes
/// no progress callback; all arguments and errors are forwarded unchanged.
pub fn encode_scan(
    frame: &FrameInfo,
    scan_components: &[ScanComponent],
    grids: &[DctGrid],
    dc_specs: &[Option<HuffmanSpec>; 4],
    ac_specs: &[Option<HuffmanSpec>; 4],
    restart_interval: u16,
) -> Result<Vec<u8>> {
    let no_progress: Option<&dyn Fn()> = None;
    encode_scan_with_progress(
        frame,
        scan_components,
        grids,
        dc_specs,
        ac_specs,
        restart_interval,
        no_progress,
    )
}
pub const JPEG_WRITE_STEPS: u32 = 20;
pub fn encode_scan_with_progress(
frame: &FrameInfo,
scan_components: &[ScanComponent],
grids: &[DctGrid],
dc_specs: &[Option<HuffmanSpec>; 4],
ac_specs: &[Option<HuffmanSpec>; 4],
restart_interval: u16,
on_progress: Option<&dyn Fn()>,
) -> Result<Vec<u8>> {
let mut dc_tables: [Option<HuffmanEncodeTable>; 4] = [None, None, None, None];
let mut ac_tables: [Option<HuffmanEncodeTable>; 4] = [None, None, None, None];
for sc in scan_components {
if dc_tables[sc.dc_table].is_none() {
let spec = dc_specs[sc.dc_table]
.as_ref()
.ok_or(JpegError::InvalidHuffmanTableId(sc.dc_table as u8))?;
dc_tables[sc.dc_table] = Some(HuffmanEncodeTable::build(&spec.bits, &spec.huffval));
}
if ac_tables[sc.ac_table].is_none() {
let spec = ac_specs[sc.ac_table]
.as_ref()
.ok_or(JpegError::InvalidHuffmanTableId(sc.ac_table as u8))?;
ac_tables[sc.ac_table] = Some(HuffmanEncodeTable::build(&spec.bits, &spec.huffval));
}
}
let mut output = Vec::new();
let mut writer = BitWriter::new();
let mut dc_pred = vec![0i32; scan_components.len()];
let mut mcu_count = 0usize;
let mut restart_count = 0u16;
let mcus_tall = frame.mcus_tall as usize;
let row_interval = if mcus_tall > 0 { (mcus_tall / JPEG_WRITE_STEPS as usize).max(1) } else { 1 };
for mcu_row in 0..mcus_tall {
if let Some(ref cb) = on_progress
&& mcu_row > 0 && mcu_row % row_interval == 0 {
cb();
}
for mcu_col in 0..frame.mcus_wide as usize {
if restart_interval > 0 && mcu_count > 0 && mcu_count.is_multiple_of(restart_interval as usize) {
let segment = std::mem::take(&mut writer).flush();
output.extend_from_slice(&segment);
let rst_marker = 0xD0 + (restart_count % 8) as u8;
output.push(0xFF);
output.push(rst_marker);
restart_count += 1;
dc_pred.fill(0);
}
for (sci, sc) in scan_components.iter().enumerate() {
let comp = &frame.components[sc.comp_idx];
let dc_tab = dc_tables[sc.dc_table].as_ref().unwrap();
let ac_tab = ac_tables[sc.ac_table].as_ref().unwrap();
for v in 0..comp.v_sampling as usize {
for h in 0..comp.h_sampling as usize {
let block_row = mcu_row * (comp.v_sampling as usize) + v;
let block_col = mcu_col * (comp.h_sampling as usize) + h;
let block = grids[sci].block(block_row, block_col);
let mut zz = [0i16; 64];
for ni in 0..64 {
zz[NATURAL_TO_ZIGZAG[ni]] = block[ni];
}
let dc_diff = (zz[0] as i32 - dc_pred[sci]) as i16;
dc_pred[sci] = zz[0] as i32;
let (dc_bits, dc_size) = encode_value(dc_diff);
let (dc_code, dc_code_len) = dc_tab.encode(dc_size)?;
writer.write_bits(dc_code, dc_code_len);
if dc_size > 0 {
writer.write_bits(dc_bits, dc_size);
}
let mut k = 1;
while k < 64 {
let mut run = 0usize;
while k + run < 64 && zz[k + run] == 0 {
run += 1;
}
if k + run >= 64 {
let (eob_code, eob_len) = ac_tab.encode(0x00)?;
writer.write_bits(eob_code, eob_len);
break;
}
while run >= 16 {
let (zrl_code, zrl_len) = ac_tab.encode(0xF0)?;
writer.write_bits(zrl_code, zrl_len);
run -= 16;
k += 16;
}
k += run;
let (ac_bits, ac_size) = encode_value(zz[k]);
let rs = ((run as u8) << 4) | ac_size;
let (ac_code, ac_code_len) = ac_tab.encode(rs)?;
writer.write_bits(ac_code, ac_code_len);
if ac_size > 0 {
writer.write_bits(ac_bits, ac_size);
}
k += 1;
}
}
}
}
mcu_count += 1;
}
}
output.extend_from_slice(&writer.flush());
Ok(output)
}
#[allow(unused_assignments)]
pub fn decode_progressive_scan(
data: &[u8],
scan_start: usize,
frame: &FrameInfo,
scan_components: &[ScanComponent],
dc_specs: &[Option<HuffmanSpec>; 4],
ac_specs: &[Option<HuffmanSpec>; 4],
restart_interval: u16,
params: &SosParams,
grids: &mut [DctGrid],
) -> Result<usize> {
let ss = params.ss as usize;
let se = params.se as usize;
let ah = params.ah;
let al = params.al;
if ss > 63 || se > 63 || ss > se {
return Err(JpegError::InvalidMarkerData("invalid spectral selection"));
}
let mut dc_tables: [Option<HuffmanDecodeTable>; 4] = [None, None, None, None];
let mut ac_tables: [Option<HuffmanDecodeTable>; 4] = [None, None, None, None];
for sc in scan_components {
if ss == 0 && ah == 0 && dc_tables[sc.dc_table].is_none() {
let spec = dc_specs[sc.dc_table]
.as_ref()
.ok_or(JpegError::InvalidHuffmanTableId(sc.dc_table as u8))?;
dc_tables[sc.dc_table] = Some(HuffmanDecodeTable::build(&spec.bits, &spec.huffval)?);
}
if ss > 0 && ac_tables[sc.ac_table].is_none() {
let spec = ac_specs[sc.ac_table]
.as_ref()
.ok_or(JpegError::InvalidHuffmanTableId(sc.ac_table as u8))?;
ac_tables[sc.ac_table] = Some(HuffmanDecodeTable::build(&spec.bits, &spec.huffval)?);
}
}
let mut reader = BitReader::new(data, scan_start);
let mut dc_pred = vec![0i32; scan_components.len()];
let mut mcu_count = 0usize;
let mut eob_run: u32 = 0;
let non_interleaved = scan_components.len() == 1 && (ss > 0 || se > 0 || frame.components.len() == 1);
if non_interleaved && scan_components.len() == 1 {
let sc = &scan_components[0];
let bw = frame.blocks_wide(sc.comp_idx);
let bt = frame.blocks_tall(sc.comp_idx);
let mut block_count = 0usize;
for block_row in 0..bt {
for block_col in 0..bw {
if restart_interval > 0 && block_count > 0 && block_count.is_multiple_of(restart_interval as usize) {
reader.byte_align();
let _rst = reader.check_restart_marker()?;
dc_pred[0] = 0;
eob_run = 0;
}
let grid = &mut grids[sc.comp_idx];
if ss == 0 {
if ah == 0 {
decode_dc_first(&mut reader, &dc_tables, sc, &mut dc_pred[0], al, grid, block_row, block_col)?;
} else {
decode_dc_refine(&mut reader, al, grid, block_row, block_col)?;
}
}
if se > 0 {
let ac_start = if ss == 0 { 1 } else { ss };
if ah == 0 {
decode_ac_first(&mut reader, &ac_tables, sc, al, ac_start, se, &mut eob_run, grid, block_row, block_col)?;
} else {
decode_ac_refine(&mut reader, &ac_tables, sc, al, ac_start, se, &mut eob_run, grid, block_row, block_col)?;
}
}
block_count += 1;
}
}
} else {
for mcu_row in 0..frame.mcus_tall as usize {
for mcu_col in 0..frame.mcus_wide as usize {
if restart_interval > 0 && mcu_count > 0 && mcu_count.is_multiple_of(restart_interval as usize) {
reader.byte_align();
let _rst = reader.check_restart_marker()?;
dc_pred.fill(0);
eob_run = 0;
}
for (sci, sc) in scan_components.iter().enumerate() {
let comp = &frame.components[sc.comp_idx];
for v in 0..comp.v_sampling as usize {
for h in 0..comp.h_sampling as usize {
let block_row = mcu_row * (comp.v_sampling as usize) + v;
let block_col = mcu_col * (comp.h_sampling as usize) + h;
let grid = &mut grids[sc.comp_idx];
if ss == 0 {
if ah == 0 {
decode_dc_first(&mut reader, &dc_tables, sc, &mut dc_pred[sci], al, grid, block_row, block_col)?;
} else {
decode_dc_refine(&mut reader, al, grid, block_row, block_col)?;
}
}
}
}
}
mcu_count += 1;
}
}
}
let end_pos = reader.position();
Ok(end_pos)
}
/// Decodes the DC coefficient of one block in a progressive DC first pass (Ah = 0).
///
/// Reads the Huffman-coded magnitude category and, when non-zero, the extra
/// bits of the DC difference, and accumulates it into `dc_pred`. The predictor
/// is saturated to `i16`, shifted left by the successive-approximation low bit
/// `al`, and stored as coefficient 0 of the block.
///
/// The accumulation wraps instead of using `+=`, so a hostile stream cannot
/// overflow the `i32` predictor and panic in debug builds. Release builds
/// already wrap, so their results are unchanged.
///
/// # Errors
/// `InvalidHuffmanTableId` if the component's DC table was not built, plus any
/// error from the table or bit reader.
fn decode_dc_first(
    reader: &mut BitReader,
    dc_tables: &[Option<HuffmanDecodeTable>; 4],
    sc: &ScanComponent,
    dc_pred: &mut i32,
    al: u8,
    grid: &mut DctGrid,
    block_row: usize,
    block_col: usize,
) -> Result<()> {
    let dc_tab = dc_tables[sc.dc_table]
        .as_ref()
        .ok_or(JpegError::InvalidHuffmanTableId(sc.dc_table as u8))?;
    let dc_size = dc_tab.decode(reader)?;
    if dc_size > 0 {
        let dc_bits = reader.read_bits(dc_size)?;
        let dc_diff = extend_sign(dc_bits, dc_size);
        *dc_pred = dc_pred.wrapping_add(dc_diff as i32);
    }
    let dc = (*dc_pred).clamp(i16::MIN as i32, i16::MAX as i32) as i16;
    grid.block_mut(block_row, block_col)[0] = dc << al;
    Ok(())
}
/// Applies one DC successive-approximation refinement bit to a block.
///
/// A single raw bit (not Huffman-coded) is read. When it is set, bit `al` of
/// coefficient 0 is OR-ed in; otherwise the coefficient is left unchanged.
fn decode_dc_refine(
    reader: &mut BitReader,
    al: u8,
    grid: &mut DctGrid,
    block_row: usize,
    block_col: usize,
) -> Result<()> {
    let refinement = reader.read_bits(1)?;
    let dc = &mut grid.block_mut(block_row, block_col)[0];
    if refinement != 0 {
        *dc |= 1i16 << al;
    }
    Ok(())
}
/// Decodes the AC band `ss..=se` of one block in a progressive AC first pass (Ah = 0).
///
/// While an end-of-band run is pending (`*eob_run > 0`), the block is left
/// untouched and the run is decremented. Otherwise run/size symbols are
/// decoded, with the following meanings:
/// - ZRL (run 15, size 0) skips 16 positions.
/// - EOBn (size 0, run < 15) starts a run of `2^run + extra_bits` blocks, this
///   one included.
/// - Any other symbol places `extend_sign(bits) << al` after `run` zeros.
///
/// # Errors
/// `InvalidHuffmanTableId` if the AC table is missing, `HuffmanDecode` when a
/// run overshoots `se`, plus any reader/table error.
fn decode_ac_first(
    reader: &mut BitReader,
    ac_tables: &[Option<HuffmanDecodeTable>; 4],
    sc: &ScanComponent,
    al: u8,
    ss: usize,
    se: usize,
    eob_run: &mut u32,
    grid: &mut DctGrid,
    block_row: usize,
    block_col: usize,
) -> Result<()> {
    let coeffs = grid.block_mut(block_row, block_col);
    if *eob_run != 0 {
        *eob_run -= 1;
        return Ok(());
    }
    let Some(table) = ac_tables[sc.ac_table].as_ref() else {
        return Err(JpegError::InvalidHuffmanTableId(sc.ac_table as u8));
    };
    let mut pos = ss;
    while pos <= se {
        let symbol = table.decode(reader)?;
        let zero_run = (symbol >> 4) as usize;
        let magnitude = symbol & 0x0F;
        match (zero_run, magnitude) {
            (15, 0) => pos += 16,
            (_, 0) => {
                *eob_run = 1u32 << (zero_run as u32);
                if zero_run > 0 {
                    *eob_run += reader.read_bits(zero_run as u8)? as u32;
                }
                // The current block is the first block of the run.
                *eob_run -= 1;
                return Ok(());
            }
            _ => {
                pos += zero_run;
                if pos > se {
                    return Err(JpegError::HuffmanDecode);
                }
                let raw = reader.read_bits(magnitude)?;
                coeffs[ZIGZAG_TO_NATURAL[pos]] = extend_sign(raw, magnitude) << al;
                pos += 1;
            }
        }
    }
    Ok(())
}
/// Decodes the AC band `ss..=se` of one block in a progressive AC refinement pass (Ah > 0).
///
/// Refinement scans (ITU-T T.81 G.1.2.3) interleave two kinds of data:
/// - Every coefficient that is already non-zero gets one correction bit as
///   the scan position passes it. A set bit grows its magnitude by `1 << al`.
/// - Run/size symbols always have size 1 (or 0). A size-1 symbol reads a sign
///   bit first, then places a new `±(1 << al)` coefficient after skipping
///   `run` still-zero coefficients.
/// - ZRL (run 15, size 0) skips 16 zero coefficients.
/// - EOBn (size 0, run < 15) starts an end-of-band run of `2^run + extra_bits`
///   blocks.
///
/// While an end-of-band run is pending, the rest of the block only receives
/// correction bits and the run is decremented.
///
/// # Errors
/// `InvalidHuffmanTableId` if the AC table is missing, `HuffmanDecode` for a
/// symbol whose size exceeds 1, plus any reader/table error.
fn decode_ac_refine(
    reader: &mut BitReader,
    ac_tables: &[Option<HuffmanDecodeTable>; 4],
    sc: &ScanComponent,
    al: u8,
    ss: usize,
    se: usize,
    eob_run: &mut u32,
    grid: &mut DctGrid,
    block_row: usize,
    block_col: usize,
) -> Result<()> {
    let block = grid.block_mut(block_row, block_col);
    let p1 = 1i16 << al;
    let m1 = (-1i16) << al;
    let ac_tab = ac_tables[sc.ac_table]
        .as_ref()
        .ok_or(JpegError::InvalidHuffmanTableId(sc.ac_table as u8))?;
    let mut k = ss;
    if *eob_run == 0 {
        while k <= se {
            let rs = ac_tab.decode(reader)?;
            let run = (rs >> 4) as usize;
            let size = rs & 0x0F;
            // Value for the newly non-zero coefficient; None for ZRL.
            let new_coef = match size {
                0 if run == 15 => None,
                0 => {
                    let mut run_len = 1u32 << (run as u32);
                    if run > 0 {
                        run_len += reader.read_bits(run as u8)? as u32;
                    }
                    *eob_run = run_len;
                    // The remainder of this block is handled by the EOB logic below.
                    break;
                }
                1 => {
                    let sign_bit = reader.read_bits(1)?;
                    Some(if sign_bit != 0 { p1 } else { m1 })
                }
                _ => return Err(JpegError::HuffmanDecode),
            };
            // Skip `run` zero coefficients, refining non-zero ones along the
            // way. The next zero is the target: it receives the new value, or
            // for ZRL it is simply the 16th zero skipped.
            let mut zeros_to_skip = run;
            while k <= se {
                let ni = ZIGZAG_TO_NATURAL[k];
                k += 1;
                if block[ni] != 0 {
                    refine_nonzero_coef(reader, &mut block[ni], p1)?;
                } else if zeros_to_skip == 0 {
                    if let Some(v) = new_coef {
                        block[ni] = v;
                    }
                    break;
                } else {
                    zeros_to_skip -= 1;
                }
            }
        }
    }
    if *eob_run > 0 {
        // Inside an end-of-band run: only correction bits remain for this block.
        while k <= se {
            let ni = ZIGZAG_TO_NATURAL[k];
            if block[ni] != 0 {
                refine_nonzero_coef(reader, &mut block[ni], p1)?;
            }
            k += 1;
        }
        *eob_run -= 1;
    }
    Ok(())
}

/// Reads one correction bit for an already non-zero coefficient and applies it.
///
/// A set bit moves the coefficient `p1` (`1 << al`) further away from zero.
/// The update is skipped when that bit is already set, as libjpeg does, so a
/// corrupt stream cannot refine the same bit twice. Wrapping arithmetic keeps
/// malformed input from causing an overflow panic.
fn refine_nonzero_coef(reader: &mut BitReader, coef: &mut i16, p1: i16) -> Result<()> {
    let bit = reader.read_bits(1)?;
    if bit != 0 && (*coef & p1) == 0 {
        *coef = if *coef > 0 {
            coef.wrapping_add(p1)
        } else {
            coef.wrapping_sub(p1)
        };
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `encode_value` yields the JPEG magnitude category plus its extra bits,
    /// and `extend_sign` must invert it for negative differences.
    #[test]
    fn encode_value_dc_diff() {
        let (pos_bits, pos_size) = encode_value(5);
        assert_eq!(pos_size, 3);
        assert_eq!(pos_bits, 5);

        let (neg_bits, neg_size) = encode_value(-3);
        assert_eq!(neg_size, 2);
        assert_eq!(extend_sign(neg_bits, neg_size), -3);
    }
}