#![forbid(unsafe_code)]
/// Errors reported while parsing a JPEG or transcoding its DCT coefficients.
#[derive(Debug)]
pub enum DctError {
    /// The input does not start with the JPEG SOI marker.
    NotJpeg,
    /// The input ended before a complete structure could be read.
    Truncated,
    /// A marker segment or the entropy-coded data is malformed.
    CorruptEntropy,
    /// The file uses a JPEG feature this module does not handle.
    Unsupported(String),
    /// A structure needed for processing was not found.
    Missing(String),
    /// Caller-supplied coefficients do not fit the target JPEG.
    Incompatible(String),
}

impl core::fmt::Display for DctError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Variants carrying a detail string are formatted and returned
        // directly; the remaining ones map to a fixed message.
        let fixed = match self {
            Self::Unsupported(detail) => {
                return write!(f, "unsupported JPEG variant: {}", detail);
            }
            Self::Missing(detail) => {
                return write!(f, "missing required JPEG structure: {}", detail);
            }
            Self::Incompatible(detail) => {
                return write!(f, "coefficient data is incompatible with this JPEG: {}", detail);
            }
            Self::NotJpeg => "not a JPEG file",
            Self::Truncated => "truncated JPEG data",
            Self::CorruptEntropy => "corrupt or malformed JPEG entropy stream",
        };
        f.write_str(fixed)
    }
}

impl std::error::Error for DctError {}
/// Per-component summary reported by `inspect`.
#[derive(Clone, Debug)]
pub struct ComponentInfo {
    /// Component identifier copied from the frame header.
    pub id: u8,
    /// Horizontal sampling factor.
    pub h_samp: u8,
    /// Vertical sampling factor.
    pub v_samp: u8,
    /// Number of coefficient blocks computed for this component.
    pub block_count: usize,
}
/// Image-level metadata reported by `inspect`.
#[derive(Clone, Debug)]
pub struct JpegInfo {
    /// Image width as read from the frame header.
    pub width: u16,
    /// Image height as read from the frame header.
    pub height: u16,
    /// One entry per frame component, in frame-header order.
    pub components: Vec<ComponentInfo>,
}
/// Quantized DCT coefficient blocks belonging to one image component.
#[derive(Clone, Debug)]
pub struct ComponentCoefficients {
    /// Component identifier copied from the frame header.
    pub id: u8,
    /// Blocks of 64 coefficients each; index 0 of a block is the DC term and
    /// the decoder places the others at their de-zigzagged positions.
    pub blocks: Vec<[i16; 64]>,
}
/// All coefficient data of a JPEG, one entry per frame component.
#[derive(Clone, Debug)]
pub struct JpegCoefficients {
    /// Components in frame-header order.
    pub components: Vec<ComponentCoefficients>,
}
/// Parses `jpeg` up to its first scan and decodes that scan's quantized DCT
/// coefficients.
///
/// # Errors
///
/// Returns [`DctError::NotJpeg`] when the SOI marker is absent, and otherwise
/// propagates whatever the header parser or the entropy decoder reports.
#[must_use = "returns the decoded coefficients or an error; ignoring it discards the result"]
pub fn read_coefficients(jpeg: &[u8]) -> Result<JpegCoefficients, DctError> {
    JpegParser::new(jpeg).and_then(|mut parser| {
        parser.parse()?;
        parser.decode_coefficients()
    })
}
/// Re-encodes `coeffs` into the entropy-coded segment of `jpeg`, keeping every
/// byte outside that segment as it was.
///
/// # Errors
///
/// Returns [`DctError::NotJpeg`] when the SOI marker is absent, and otherwise
/// propagates whatever the header parser or the encoder reports.
#[must_use = "returns the re-encoded JPEG bytes or an error; ignoring it discards the result"]
pub fn write_coefficients(jpeg: &[u8], coeffs: &JpegCoefficients) -> Result<Vec<u8>, DctError> {
    JpegParser::new(jpeg).and_then(|mut parser| {
        parser.parse()?;
        parser.encode_coefficients(jpeg, coeffs)
    })
}
/// Returns the number of coefficient blocks per frame component without
/// decoding the entropy-coded data.
///
/// # Errors
///
/// Returns [`DctError::NotJpeg`] when the SOI marker is absent, and otherwise
/// propagates header-parsing errors.
#[must_use = "returns block counts or an error; ignoring it discards the result"]
pub fn block_count(jpeg: &[u8]) -> Result<Vec<usize>, DctError> {
    JpegParser::new(jpeg).and_then(|mut parser| {
        parser.parse()?;
        parser.block_counts()
    })
}
/// Reads the image dimensions and per-component layout of `jpeg` without
/// decoding the entropy-coded data.
///
/// # Errors
///
/// Returns [`DctError::NotJpeg`] when the SOI marker is absent, and otherwise
/// propagates header-parsing errors.
#[must_use = "returns image metadata or an error; ignoring it discards the result"]
pub fn inspect(jpeg: &[u8]) -> Result<JpegInfo, DctError> {
    let mut parser = JpegParser::new(jpeg)?;
    parser.parse()?;
    // `block_counts` yields exactly one entry per frame component, so the two
    // sequences can be walked in lock-step.
    let counts = parser.block_counts()?;
    let mut components = Vec::with_capacity(parser.frame_components.len());
    for (fc, block_count) in parser.frame_components.iter().zip(counts) {
        components.push(ComponentInfo {
            id: fc.id,
            h_samp: fc.h_samp,
            v_samp: fc.v_samp,
            block_count,
        });
    }
    Ok(JpegInfo {
        width: parser.image_width,
        height: parser.image_height,
        components,
    })
}
/// Decodes `jpeg` and counts its AC coefficients whose magnitude is at least 2.
///
/// # Errors
///
/// Propagates every error of [`read_coefficients`].
#[must_use = "returns the eligible AC coefficient count or an error; ignoring it discards the result"]
pub fn eligible_ac_count(jpeg: &[u8]) -> Result<usize, DctError> {
    read_coefficients(jpeg).map(|coeffs| coeffs.eligible_ac_count())
}
impl JpegCoefficients {
    /// Counts the AC coefficients (every block entry except index 0) whose
    /// magnitude is at least 2, across all components.
    #[must_use]
    pub fn eligible_ac_count(&self) -> usize {
        self.components
            .iter()
            .flat_map(|c| c.blocks.iter())
            .flat_map(|b| b[1..].iter())
            // `unsigned_abs` is used instead of `abs`: the blocks are public
            // and may hold `i16::MIN`, for which `abs` overflows.
            .filter(|&&v| v.unsigned_abs() >= 2)
            .count()
    }
}
/// Zigzag lookup table: entry `k` is the index inside a 64-entry block at
/// which the `k`-th coefficient of the scan order is stored. The decoder and
/// encoder both address blocks as `block[ZIGZAG[k] as usize]`.
#[rustfmt::skip]
const ZIGZAG: [u8; 64] = [
0, 1, 8, 16, 9, 2, 3, 10,
17, 24, 32, 25, 18, 11, 4, 5,
12, 19, 26, 33, 40, 48, 41, 34,
27, 20, 13, 6, 7, 14, 21, 28,
35, 42, 49, 56, 57, 50, 43, 36,
29, 22, 15, 23, 30, 37, 44, 51,
58, 59, 52, 45, 38, 31, 39, 46,
53, 60, 61, 54, 47, 55, 62, 63,
];
/// Largest MCU count `decode_coefficients` accepts; images needing more are
/// rejected as unsupported before any block storage is allocated.
const MAX_MCU_COUNT: usize = 1_048_576;
/// Returns the magnitude category of `value`: the number of bits needed to
/// represent `|value|`, capped at 15. Zero has category 0.
#[inline]
fn category(value: i16) -> u8 {
    if value == 0 {
        return 0;
    }
    let abs = value.unsigned_abs();
    let cat = (16u32 - abs.leading_zeros()) as u8;
    cat.min(15)
}

/// Splits `value` into its magnitude category and the additional bits that
/// follow the Huffman code.
///
/// Returns `(category, bits, bit_count)`, where `bit_count` always equals
/// `category` and `bits` never has a bit set above `bit_count`.
///
/// Positive values are emitted as-is; negative values are emitted as
/// `value + 2^category - 1`. `i16::MIN` has no 15-bit representation and is
/// saturated to `-32767`.
#[inline]
fn encode_value(value: i16) -> (u8, u16, u8) {
    // Saturate the single value that does not fit category 15.
    let value = value.max(-i16::MAX);
    let cat = category(value);
    if cat == 0 {
        return (0, 0, 0);
    }
    // Work in i32: `(1 << 15) - 1` is not representable in i16 arithmetic.
    let full = (1i32 << cat) - 1;
    let raw = if value > 0 {
        i32::from(value)
    } else {
        full + i32::from(value)
    };
    // `raw` is already within `0..=full`; the mask makes that explicit.
    (cat, (raw & full) as u16, cat)
}
/// A Huffman table prepared for both decoding and encoding.
#[derive(Clone)]
struct HuffTable {
    /// Decode table indexed by 16 bits of input; each entry packs
    /// `(symbol << 8) | code_length`, and a zero length means "no code".
    lut: Vec<u16>,
    /// Encode table indexed by symbol: `(code, code_length)`; a zero length
    /// means the symbol has no code.
    encode: [(u16, u8); 256],
}

impl std::fmt::Debug for HuffTable {
    /// Prints only the number of symbols that have a code, not the tables.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut populated = 0usize;
        for &(_, len) in self.encode.iter() {
            if len != 0 {
                populated += 1;
            }
        }
        f.debug_struct("HuffTable")
            .field("encode_entries", &populated)
            .finish()
    }
}
impl HuffTable {
    /// Builds the decode and encode tables from a DHT table definition.
    ///
    /// `counts[i]` is the number of codes of length `i + 1`; `symbols` lists
    /// the symbols in order of increasing code length. Codes are assigned
    /// canonically: consecutive values within a length, doubled when moving
    /// to the next length.
    ///
    /// # Errors
    ///
    /// Returns [`DctError::CorruptEntropy`] when `symbols` is shorter than
    /// `counts` requires, or when the counts describe more codes than fit
    /// into some code length.
    fn from_jpeg(counts: &[u8; 16], symbols: &[u8]) -> Result<Self, DctError> {
        let mut encode = [(0u16, 0u8); 256];
        let mut lut = vec![0u16; 65536];
        // The running code is kept in u32. In u16 the doubling below drops
        // the top bit, which let over-subscribed tables pass the range check,
        // and the increment after code 0xFFFF overflowed.
        let mut code: u32 = 0;
        let mut remaining = symbols.iter();
        for len in 1u8..=16u8 {
            let count = counts[usize::from(len - 1)];
            for _ in 0..count {
                let sym = *remaining.next().ok_or(DctError::CorruptEntropy)?;
                if code >= (1u32 << len) {
                    return Err(DctError::CorruptEntropy);
                }
                // `code < 2^len <= 2^16`, so the narrowing cast is lossless.
                encode[usize::from(sym)] = (code as u16, len);
                // Every 16-bit window starting with this code maps to it.
                let shift = 16 - u32::from(len);
                let base = (code as usize) << shift;
                let spread = 1usize << shift;
                let entry = (u16::from(sym) << 8) | u16::from(len);
                lut[base..base + spread].fill(entry);
                code += 1;
            }
            code <<= 1;
        }
        Ok(HuffTable { lut, encode })
    }
}
/// MSB-first bit reader over entropy-coded JPEG data.
///
/// Stuffed `FF 00` pairs are read as a single `FF` byte, and reading stops in
/// front of any other `FF xx` pair.
struct BitReader<'a> {
    /// Entropy-coded bytes being read.
    data: &'a [u8],
    /// Index of the next byte not yet loaded into `buf`.
    pos: usize,
    /// Bit accumulator; the low `bits` bits are valid.
    buf: u64,
    /// Number of valid bits currently held in `buf`.
    bits: u8,
}

impl<'a> BitReader<'a> {
    /// Creates a reader positioned at the start of `data`.
    fn new(data: &'a [u8]) -> Self {
        Self {
            data,
            pos: 0,
            buf: 0,
            bits: 0,
        }
    }

    /// Loads whole bytes into the accumulator while at least 8 bits are free,
    /// stopping at the end of the data, at a marker, or at a trailing `FF`.
    fn refill(&mut self) {
        while self.bits <= 56 {
            let (byte, step) = match (self.data.get(self.pos), self.data.get(self.pos + 1)) {
                // Stuffed byte: `FF 00` stands for a literal 0xFF.
                (Some(&0xFF), Some(&0x00)) => (0xFFu8, 2usize),
                // A marker, or a lone 0xFF as the final byte: stop here.
                (Some(&0xFF), _) => break,
                (Some(&plain), _) => (plain, 1usize),
                (None, _) => break,
            };
            self.pos += step;
            self.buf = (self.buf << 8) | u64::from(byte);
            self.bits += 8;
        }
    }

    /// Returns the next `n` bits without consuming them.
    ///
    /// # Errors
    ///
    /// Returns [`DctError::Truncated`] when fewer than `n` bits are available
    /// even after refilling.
    fn peek(&mut self, n: u8) -> Result<u16, DctError> {
        if self.bits < n {
            self.refill();
            if self.bits < n {
                return Err(DctError::Truncated);
            }
        }
        let shifted = self.buf >> (self.bits - n);
        let mask = (1u64 << n) - 1;
        Ok((shifted & mask) as u16)
    }

    /// Drops `n` bits from the front of the accumulator.
    fn consume(&mut self, n: u8) {
        debug_assert!(self.bits >= n);
        self.bits -= n;
        let keep = (1u64 << self.bits) - 1;
        self.buf &= keep;
    }

    /// Reads and consumes `n` bits; `n == 0` yields 0 without touching input.
    fn read_bits(&mut self, n: u8) -> Result<u16, DctError> {
        if n == 0 {
            Ok(0)
        } else {
            let value = self.peek(n)?;
            self.consume(n);
            Ok(value)
        }
    }

    /// Decodes one Huffman symbol using `table`'s 16-bit lookup table.
    ///
    /// # Errors
    ///
    /// Returns [`DctError::CorruptEntropy`] when the upcoming bits match no
    /// code, and [`DctError::Truncated`] when the matched code is longer than
    /// the bits that are left.
    fn decode_huffman(&mut self, table: &HuffTable) -> Result<u8, DctError> {
        if self.bits < 16 {
            self.refill();
        }
        // With fewer than 16 bits available the window is zero-padded on the
        // right.
        let window = if self.bits >= 16 {
            self.buf >> (self.bits - 16)
        } else {
            self.buf << (16 - self.bits)
        };
        let entry = table.lut[(window & 0xFFFF) as usize];
        // High byte: symbol. Low byte: code length.
        let [sym, len] = entry.to_be_bytes();
        if len == 0 {
            return Err(DctError::CorruptEntropy);
        }
        if self.bits < len {
            return Err(DctError::Truncated);
        }
        self.consume(len);
        Ok(sym)
    }

    /// Discards buffered bits and, if the next two bytes form a restart
    /// marker (`FF D0`..`FF D7`), skips it.
    ///
    /// Returns whether a restart marker was found and skipped.
    fn sync_restart(&mut self) -> bool {
        self.bits = 0;
        self.buf = 0;
        let at_marker = matches!(
            self.data.get(self.pos..self.pos + 2),
            Some([0xFF, 0xD0..=0xD7])
        );
        if at_marker {
            self.pos += 2;
        }
        at_marker
    }
}
/// MSB-first bit writer producing entropy-coded JPEG data.
///
/// Every emitted `0xFF` byte is followed by a stuffed `0x00`.
struct BitWriter {
    /// Bytes emitted so far.
    out: Vec<u8>,
    /// Bit accumulator; the low `bits` bits are pending output.
    buf: u64,
    /// Number of pending bits in `buf` (always below 8 between calls).
    bits: u8,
}

impl BitWriter {
    /// Creates a writer whose output buffer has room for `cap` bytes.
    fn with_capacity(cap: usize) -> Self {
        BitWriter {
            out: Vec::with_capacity(cap),
            buf: 0,
            bits: 0,
        }
    }

    /// Appends the low `n` bits of `value`, most significant bit first.
    ///
    /// Bits of `value` above position `n` are ignored, so an out-of-range
    /// `value` cannot overwrite bits that were queued earlier.
    fn write_bits(&mut self, value: u16, n: u8) {
        if n == 0 {
            return;
        }
        let masked = u64::from(value) & ((1u64 << n) - 1);
        self.buf = (self.buf << n) | masked;
        self.bits += n;
        while self.bits >= 8 {
            self.bits -= 8;
            let byte = ((self.buf >> self.bits) & 0xFF) as u8;
            self.push_stuffed(byte);
            self.buf &= (1u64 << self.bits) - 1;
        }
    }

    /// Emits any pending bits as one byte, padding the low end with 1 bits.
    fn flush(&mut self) {
        if self.bits > 0 {
            let pad = 8 - self.bits;
            let byte = (((self.buf << pad) | ((1u64 << pad) - 1)) & 0xFF) as u8;
            self.push_stuffed(byte);
            self.bits = 0;
            self.buf = 0;
        }
    }

    /// Flushes pending bits and writes restart marker `FF D0`..`FF D7`,
    /// selected by the low three bits of `n`.
    fn write_restart_marker(&mut self, n: u8) {
        self.flush();
        self.out.push(0xFF);
        self.out.push(0xD0 | (n & 0x07));
    }

    /// Pushes one data byte, adding the stuffing byte after `0xFF`.
    fn push_stuffed(&mut self, byte: u8) {
        self.out.push(byte);
        if byte == 0xFF {
            self.out.push(0x00);
        }
    }
}
/// One component entry from the frame header (SOF).
#[derive(Clone, Debug)]
struct FrameComponent {
    /// Component identifier.
    id: u8,
    /// Horizontal sampling factor (high nibble of the sampling byte).
    h_samp: u8,
    /// Vertical sampling factor (low nibble of the sampling byte).
    v_samp: u8,
    /// Third byte of the component entry; stored but not read.
    #[allow(dead_code)]
    qt_id: u8,
}
/// One component entry from the scan header (SOS).
#[derive(Clone, Debug)]
struct ScanComponent {
    /// Index of the matching entry in the parser's frame component list.
    comp_idx: usize,
    /// Selector (0..=3) of the DC Huffman table.
    dc_table: usize,
    /// Selector (0..=3) of the AC Huffman table.
    ac_table: usize,
}
/// Parser state collected from the JPEG header segments up to and including
/// the first scan header.
struct JpegParser<'a> {
    /// Complete JPEG file contents.
    data: &'a [u8],
    /// Read cursor into `data` during header parsing.
    pos: usize,
    /// Offset of the first entropy-coded byte after the scan header.
    entropy_start: usize,
    /// Length in bytes of the entropy-coded segment.
    entropy_len: usize,
    /// Components declared by the frame header, in header order.
    frame_components: Vec<FrameComponent>,
    /// Components declared by the scan header, in header order.
    scan_components: Vec<ScanComponent>,
    /// DC Huffman tables by selector.
    dc_tables: [Option<HuffTable>; 4],
    /// AC Huffman tables by selector.
    ac_tables: [Option<HuffTable>; 4],
    /// Restart interval in MCUs from the DRI segment; 0 when none was read.
    restart_interval: u16,
    /// Image width from the frame header.
    image_width: u16,
    /// Image height from the frame header.
    image_height: u16,
}
impl<'a> JpegParser<'a> {
fn new(data: &'a [u8]) -> Result<Self, DctError> {
if data.len() < 2 || data[0] != 0xFF || data[1] != 0xD8 {
return Err(DctError::NotJpeg);
}
Ok(JpegParser {
data,
pos: 2,
entropy_start: 0,
entropy_len: 0,
frame_components: Vec::new(),
scan_components: Vec::new(),
dc_tables: [None, None, None, None],
ac_tables: [None, None, None, None],
restart_interval: 0,
image_width: 0,
image_height: 0,
})
}
/// Reads a big-endian `u16` at the cursor and advances past it.
///
/// # Errors
///
/// Returns [`DctError::Truncated`] when fewer than two bytes remain; the
/// cursor is left untouched in that case.
fn read_u16(&mut self) -> Result<u16, DctError> {
    match self.data.get(self.pos..self.pos + 2) {
        Some(&[hi, lo]) => {
            self.pos += 2;
            Ok(u16::from_be_bytes([hi, lo]))
        }
        _ => Err(DctError::Truncated),
    }
}
/// Walks the marker segments after SOI until the first SOS, recording frame,
/// Huffman-table and restart-interval data, then locates the entropy-coded
/// segment that follows the scan header.
///
/// # Errors
///
/// * [`DctError::Missing`] when the data ends or EOI appears before any SOS.
/// * [`DctError::Unsupported`] for SOF2, SOF3, SOF9, SOF10 and SOF11 frames.
/// * [`DctError::CorruptEntropy`] when a byte other than `0xFF` sits where a
///   marker is expected, or a skipped segment declares a length below 2.
/// * [`DctError::Truncated`] when a segment runs past the end of the data.
/// * Any error from the individual segment parsers.
fn parse(&mut self) -> Result<(), DctError> {
    loop {
        match self.data.get(self.pos) {
            None => return Err(DctError::Missing("SOS marker".into())),
            Some(&0xFF) => {}
            Some(_) => return Err(DctError::CorruptEntropy),
        }
        // Any number of 0xFF bytes may precede the marker code.
        let fill = self.data[self.pos..]
            .iter()
            .take_while(|&&b| b == 0xFF)
            .count();
        self.pos += fill;
        let marker = *self.data.get(self.pos).ok_or(DctError::Truncated)?;
        self.pos += 1;

        match marker {
            0xD8 => {}
            0xD9 => return Err(DctError::Missing("SOS before EOI".into())),
            0xC0 | 0xC1 => self.parse_sof()?,
            0xC2 | 0xC3 | 0xC9 | 0xCA | 0xCB => {
                let what = match marker {
                    0xC2 => "progressive JPEG (SOF2)",
                    0xC3 => "lossless JPEG (SOF3)",
                    0xC9 => "arithmetic coding (SOF9)",
                    0xCA => "progressive arithmetic (SOF10)",
                    _ => "lossless arithmetic (SOF11)",
                };
                return Err(DctError::Unsupported(what.into()));
            }
            0xC4 => self.parse_dht()?,
            0xDD => self.parse_dri()?,
            0xDA => {
                self.parse_sos_header()?;
                self.entropy_start = self.pos;
                self.entropy_len = self.find_entropy_end();
                return Ok(());
            }
            _ => {
                // Any other segment is skipped using its declared length,
                // which counts the two length bytes themselves.
                let body = usize::from(self.read_u16()?)
                    .checked_sub(2)
                    .ok_or(DctError::CorruptEntropy)?;
                if body > self.data.len() - self.pos {
                    return Err(DctError::Truncated);
                }
                self.pos += body;
            }
        }
    }
}
/// Parses a frame header (SOF0/SOF1) segment: image dimensions and the
/// component list. The precision byte is skipped.
///
/// # Errors
///
/// * [`DctError::CorruptEntropy`] for a segment length below 8 or a zero
///   sampling factor.
/// * [`DctError::Truncated`] when the segment or its component list does not
///   fit.
/// * [`DctError::Unsupported`] for a zero dimension or a component count
///   outside `1..=4`.
fn parse_sof(&mut self) -> Result<(), DctError> {
    let data = self.data;
    let seg_len = usize::from(self.read_u16()?);
    if seg_len < 8 {
        return Err(DctError::CorruptEntropy);
    }
    let end = self.pos + seg_len - 2;
    if end > data.len() {
        return Err(DctError::Truncated);
    }

    // Fixed part: precision, height (2 bytes), width (2 bytes), count.
    let fixed = &data[self.pos..self.pos + 6];
    self.pos += 6;
    self.image_height = u16::from_be_bytes([fixed[1], fixed[2]]);
    self.image_width = u16::from_be_bytes([fixed[3], fixed[4]]);
    if self.image_width == 0 || self.image_height == 0 {
        return Err(DctError::Unsupported("zero image dimension".into()));
    }
    let ncomp = usize::from(fixed[5]);
    if ncomp == 0 || ncomp > 4 {
        return Err(DctError::Unsupported(format!("{} components", ncomp)));
    }

    let specs_end = self.pos + ncomp * 3;
    if specs_end > end {
        return Err(DctError::Truncated);
    }
    self.frame_components.clear();
    // Each component entry is three bytes: id, packed sampling factors, and
    // a third byte kept as `qt_id`.
    for spec in data[self.pos..specs_end].chunks_exact(3) {
        let (h_samp, v_samp) = (spec[1] >> 4, spec[1] & 0x0F);
        if h_samp == 0 || v_samp == 0 {
            return Err(DctError::CorruptEntropy);
        }
        self.frame_components.push(FrameComponent {
            id: spec[0],
            h_samp,
            v_samp,
            qt_id: spec[2],
        });
    }
    self.pos = end;
    Ok(())
}
/// Parses a DHT segment, which may define several Huffman tables, and stores
/// each one in the DC or AC slot it names.
///
/// # Errors
///
/// * [`DctError::CorruptEntropy`] for a segment length below 2, a table
///   class above 1, a table slot above 3, more than 256 symbols, or a table
///   that `HuffTable::from_jpeg` rejects.
/// * [`DctError::Truncated`] when the segment or one of its tables does not
///   fit.
fn parse_dht(&mut self) -> Result<(), DctError> {
    let data = self.data;
    let seg_len = usize::from(self.read_u16()?);
    if seg_len < 2 {
        return Err(DctError::CorruptEntropy);
    }
    let end = self.pos + seg_len - 2;
    if end > data.len() {
        return Err(DctError::Truncated);
    }

    while self.pos < end {
        // High nibble: table class (0 = DC, 1 = AC). Low nibble: slot.
        let tc_th = *data.get(self.pos).ok_or(DctError::Truncated)?;
        self.pos += 1;
        let class = tc_th >> 4;
        let slot = usize::from(tc_th & 0x0F);
        if class > 1 || slot > 3 {
            return Err(DctError::CorruptEntropy);
        }

        let counts_end = self.pos + 16;
        if counts_end > end {
            return Err(DctError::Truncated);
        }
        let mut counts = [0u8; 16];
        counts.copy_from_slice(&data[self.pos..counts_end]);
        self.pos = counts_end;

        let total: usize = counts.iter().copied().map(usize::from).sum();
        if total > 256 {
            return Err(DctError::CorruptEntropy);
        }
        let symbols_end = self.pos + total;
        if symbols_end > end {
            return Err(DctError::Truncated);
        }
        let symbols = &data[self.pos..symbols_end];
        self.pos = symbols_end;

        let table = HuffTable::from_jpeg(&counts, symbols)?;
        let slots = if class == 0 {
            &mut self.dc_tables
        } else {
            &mut self.ac_tables
        };
        slots[slot] = Some(table);
    }
    self.pos = end;
    Ok(())
}
/// Parses a DRI segment and records the restart interval.
///
/// # Errors
///
/// Returns [`DctError::CorruptEntropy`] when the segment length is not 4 and
/// [`DctError::Truncated`] when the segment is cut short.
fn parse_dri(&mut self) -> Result<(), DctError> {
    match self.read_u16()? {
        4 => {
            self.restart_interval = self.read_u16()?;
            Ok(())
        }
        _ => Err(DctError::CorruptEntropy),
    }
}
/// Parses the scan header (SOS): the components taking part in the scan and
/// the Huffman table selectors of each. The rest of the header is skipped.
///
/// # Errors
///
/// * [`DctError::CorruptEntropy`] for a segment length below 3, a component
///   count of zero or above the frame's component count, a table selector
///   above 3, or a component that is listed more than once.
/// * [`DctError::Truncated`] when the segment or its component list does not
///   fit.
/// * [`DctError::Missing`] when a scan component id is not in the frame.
fn parse_sos_header(&mut self) -> Result<(), DctError> {
    let len = self.read_u16()? as usize;
    if len < 3 {
        return Err(DctError::CorruptEntropy);
    }
    let end = self.pos + len - 2;
    if end > self.data.len() {
        return Err(DctError::Truncated);
    }
    let ns = self.data[self.pos] as usize;
    self.pos += 1;
    if ns == 0 || ns > self.frame_components.len() {
        return Err(DctError::CorruptEntropy);
    }
    if self.pos + ns * 2 > end {
        return Err(DctError::Truncated);
    }
    self.scan_components.clear();
    for _ in 0..ns {
        let comp_id = self.data[self.pos];
        // High nibble: DC table selector. Low nibble: AC table selector.
        let td_ta = self.data[self.pos + 1];
        self.pos += 2;
        let dc_table = (td_ta >> 4) as usize;
        let ac_table = (td_ta & 0x0F) as usize;
        if dc_table > 3 || ac_table > 3 {
            return Err(DctError::CorruptEntropy);
        }
        let comp_idx = self
            .frame_components
            .iter()
            .position(|fc| fc.id == comp_id)
            .ok_or_else(|| DctError::Missing(format!("component id {} in frame", comp_id)))?;
        // A component may appear only once per scan. Block counts are kept
        // per component, so a repeated entry would make the encoder read
        // past the validated block list.
        if self.scan_components.iter().any(|sc| sc.comp_idx == comp_idx) {
            return Err(DctError::CorruptEntropy);
        }
        self.scan_components.push(ScanComponent {
            comp_idx,
            dc_table,
            ac_table,
        });
    }
    // Skip whatever remains of the header.
    self.pos = end;
    Ok(())
}
/// Returns the length of the entropy-coded segment starting at
/// `entropy_start`.
///
/// Stuffed bytes (`FF 00`) and restart markers (`FF D0`..`FF D7`) count as
/// part of the segment; it ends in front of the first other `FF xx` pair, or
/// at the end of the data when there is none.
fn find_entropy_end(&self) -> usize {
    let start = self.entropy_start;
    let mut cursor = start;
    while let Some(&byte) = self.data.get(cursor) {
        match (byte, self.data.get(cursor + 1)) {
            (0xFF, Some(&(0x00 | 0xD0..=0xD7))) => cursor += 2,
            (0xFF, Some(_)) => return cursor - start,
            // Ordinary byte, or a lone 0xFF as the final byte.
            _ => cursor += 1,
        }
    }
    self.data.len() - start
}
/// Largest horizontal sampling factor among the frame components, or 1 when
/// there are none.
fn max_h_samp(&self) -> u8 {
    let mut best: Option<u8> = None;
    for comp in &self.frame_components {
        best = Some(best.map_or(comp.h_samp, |b| b.max(comp.h_samp)));
    }
    best.unwrap_or(1)
}
/// Largest vertical sampling factor among the frame components, or 1 when
/// there are none.
fn max_v_samp(&self) -> u8 {
    let mut best: Option<u8> = None;
    for comp in &self.frame_components {
        best = Some(best.map_or(comp.v_samp, |b| b.max(comp.v_samp)));
    }
    best.unwrap_or(1)
}
/// Number of MCU columns: the image width divided by the MCU width
/// (`8 * max_h_samp`), rounded up.
fn mcu_cols(&self) -> usize {
    let mcu_width = usize::from(self.max_h_samp()) * 8;
    let width = usize::from(self.image_width);
    width / mcu_width + usize::from(width % mcu_width != 0)
}
/// Number of MCU rows: the image height divided by the MCU height
/// (`8 * max_v_samp`), rounded up.
fn mcu_rows(&self) -> usize {
    let mcu_height = usize::from(self.max_v_samp()) * 8;
    let height = usize::from(self.image_height);
    height / mcu_height + usize::from(height % mcu_height != 0)
}
/// Total number of MCUs in the image (`mcu_cols * mcu_rows`).
///
/// # Errors
///
/// Returns [`DctError::Unsupported`] when the product overflows `usize`.
fn mcu_count(&self) -> Result<usize, DctError> {
    match self.mcu_cols().checked_mul(self.mcu_rows()) {
        Some(total) => Ok(total),
        None => Err(DctError::Unsupported(
            "image dimensions overflow usize".into(),
        )),
    }
}
/// Number of blocks each scan component contributes to one MCU
/// (`h_samp * v_samp`), in scan-component order.
fn du_per_mcu(&self) -> Vec<usize> {
    let mut units = Vec::with_capacity(self.scan_components.len());
    for sc in &self.scan_components {
        let fc = &self.frame_components[sc.comp_idx];
        units.push(usize::from(fc.h_samp) * usize::from(fc.v_samp));
    }
    units
}
/// Total number of blocks per frame component, indexed like
/// `frame_components`. Components that are not part of the scan get 0.
///
/// # Errors
///
/// Propagates the error of `mcu_count`.
fn block_counts(&self) -> Result<Vec<usize>, DctError> {
    let n_mcu = self.mcu_count()?;
    let mut totals = vec![0usize; self.frame_components.len()];
    for (sc, per_mcu) in self.scan_components.iter().zip(self.du_per_mcu()) {
        totals[sc.comp_idx] = n_mcu * per_mcu;
    }
    Ok(totals)
}
/// Decodes the entropy-coded segment found by `parse` into per-component
/// coefficient blocks.
///
/// Blocks are produced in scan order: MCU by MCU, scan component by scan
/// component, `h_samp * v_samp` blocks each. Within a block the DC value is
/// stored with the running prediction already added, and each coefficient is
/// written at `ZIGZAG[k]` for its scan position `k`.
///
/// # Errors
///
/// * [`DctError::Unsupported`] when the MCU count exceeds `MAX_MCU_COUNT` or
///   overflows.
/// * [`DctError::Missing`] when a Huffman table named by the scan header was
///   never defined.
/// * [`DctError::Truncated`] / [`DctError::CorruptEntropy`] from the bit
///   reader, or `CorruptEntropy` when a component receives more blocks than
///   were allocated for it.
fn decode_coefficients(&self) -> Result<JpegCoefficients, DctError> {
let entropy = &self.data[self.entropy_start..self.entropy_start + self.entropy_len];
let n_mcu = self.mcu_count()?;
// Refuse oversized images before allocating block storage.
if n_mcu > MAX_MCU_COUNT {
return Err(DctError::Unsupported(format!(
"image too large ({} MCUs; max {})",
n_mcu, MAX_MCU_COUNT
)));
}
let du = self.du_per_mcu();
let counts = self.block_counts()?;
// One zero-filled block vector per frame component.
let mut comp_blocks: Vec<Vec<[i16; 64]>> =
counts.iter().map(|&c| vec![[0i16; 64]; c]).collect();
// Next block index to fill, per frame component.
let mut comp_block_idx: Vec<usize> = vec![0; self.frame_components.len()];
// Running DC prediction, per scan component.
let mut dc_pred: Vec<i16> = vec![0; self.scan_components.len()];
let mut reader = BitReader::new(entropy);
let restart_interval = self.restart_interval as usize;
for mcu_idx in 0..n_mcu {
// At each restart boundary: drop leftover bits, step over the marker and
// reset the DC predictions.
if restart_interval > 0 && mcu_idx > 0 && mcu_idx % restart_interval == 0 {
// NOTE(review): the returned flag (whether a marker was actually present)
// is ignored here.
reader.sync_restart();
for p in dc_pred.iter_mut() {
*p = 0;
}
}
for (sc_idx, sc) in self.scan_components.iter().enumerate() {
let dc_table = self.dc_tables[sc.dc_table]
.as_ref()
.ok_or_else(|| DctError::Missing(format!("DC table {}", sc.dc_table)))?;
let ac_table = self.ac_tables[sc.ac_table]
.as_ref()
.ok_or_else(|| DctError::Missing(format!("AC table {}", sc.ac_table)))?;
for _du_i in 0..du[sc_idx] {
let mut block = [0i16; 64];
// DC: the Huffman symbol is the bit length of the difference to the
// previous DC value of this component.
let dc_cat = reader.decode_huffman(dc_table)?;
let dc_cat = dc_cat.min(15);
let dc_bits = reader.read_bits(dc_cat)?;
let dc_diff = decode_magnitude(dc_cat, dc_bits);
dc_pred[sc_idx] = dc_pred[sc_idx].saturating_add(dc_diff);
block[ZIGZAG[0] as usize] = dc_pred[sc_idx];
// AC: each symbol packs a zero run (high nibble) and the bit length of
// the following value (low nibble).
let mut k = 1usize;
while k < 64 {
let rs = reader.decode_huffman(ac_table)?;
// 0x00: end of block, the remaining coefficients stay zero.
if rs == 0x00 {
break;
}
// 0xF0: skip sixteen positions.
if rs == 0xF0 {
k += 16;
continue;
}
let run = (rs >> 4) as usize;
let cat = (rs & 0x0F).min(15);
k += run;
// A run that leaves the block ends it without reading value bits.
if k >= 64 {
break;
}
let bits = reader.read_bits(cat)?;
let val = decode_magnitude(cat, bits);
block[ZIGZAG[k] as usize] = val;
k += 1;
}
let block_idx = comp_block_idx[sc.comp_idx];
// More blocks for this component than were allocated.
if block_idx >= comp_blocks[sc.comp_idx].len() {
return Err(DctError::CorruptEntropy);
}
comp_blocks[sc.comp_idx][block_idx] = block;
comp_block_idx[sc.comp_idx] += 1;
}
}
}
// Pair each frame component id with its decoded blocks.
let components = self
.frame_components
.iter()
.zip(comp_blocks)
.map(|(fc, blocks)| ComponentCoefficients { id: fc.id, blocks })
.collect();
Ok(JpegCoefficients { components })
}
/// Re-encodes `coeffs` with the Huffman tables of the parsed JPEG and splices
/// the result into `original` in place of its entropy-coded segment.
///
/// The output is `original[..entropy_start]`, then the new entropy-coded
/// bytes, then everything of `original` after the old segment. Restart
/// markers are written at the parsed restart interval.
///
/// # Errors
///
/// * [`DctError::Incompatible`] when `coeffs` differs from the parsed JPEG in
///   component count, component ids or per-component block counts.
/// * [`DctError::Missing`] when a Huffman table named by the scan header was
///   never defined.
/// * [`DctError::CorruptEntropy`] when a symbol that has to be written has no
///   code in the selected table.
/// * [`DctError::Unsupported`] when the MCU count overflows.
fn encode_coefficients(
&self,
original: &[u8],
coeffs: &JpegCoefficients,
) -> Result<Vec<u8>, DctError> {
// The supplied coefficients must match the parsed layout exactly.
if coeffs.components.len() != self.frame_components.len() {
return Err(DctError::Incompatible(format!(
"expected {} components, got {}",
self.frame_components.len(),
coeffs.components.len()
)));
}
let counts = self.block_counts()?;
for (i, (cc, &expected)) in coeffs.components.iter().zip(counts.iter()).enumerate() {
if cc.id != self.frame_components[i].id {
return Err(DctError::Incompatible(format!(
"component {}: expected id {}, got {}",
i, self.frame_components[i].id, cc.id
)));
}
if cc.blocks.len() != expected {
return Err(DctError::Incompatible(format!(
"component {}: expected {} blocks, got {}",
i,
expected,
cc.blocks.len()
)));
}
}
let n_mcu = self.mcu_count()?;
let du = self.du_per_mcu();
// The old segment length is used as the initial capacity.
let mut writer = BitWriter::with_capacity(self.entropy_len);
// Previous DC value, per scan component.
let mut dc_pred: Vec<i16> = vec![0; self.scan_components.len()];
// Next block index to emit, per frame component.
let mut comp_block_idx: Vec<usize> = vec![0; self.frame_components.len()];
let restart_interval = self.restart_interval as usize;
let mut rst_count: u8 = 0;
for mcu_idx in 0..n_mcu {
// At each restart boundary: pad, write the next marker of the RST0..RST7
// cycle and reset the DC predictions.
if restart_interval > 0 && mcu_idx > 0 && mcu_idx % restart_interval == 0 {
writer.write_restart_marker(rst_count);
rst_count = rst_count.wrapping_add(1) & 0x07;
for p in dc_pred.iter_mut() {
*p = 0;
}
}
for (sc_idx, sc) in self.scan_components.iter().enumerate() {
let dc_table = self.dc_tables[sc.dc_table]
.as_ref()
.ok_or_else(|| DctError::Missing(format!("DC table {}", sc.dc_table)))?;
let ac_table = self.ac_tables[sc.ac_table]
.as_ref()
.ok_or_else(|| DctError::Missing(format!("AC table {}", sc.ac_table)))?;
for _du_i in 0..du[sc_idx] {
let block = &coeffs.components[sc.comp_idx].blocks[comp_block_idx[sc.comp_idx]];
comp_block_idx[sc.comp_idx] += 1;
// DC: encode the difference to the previous DC value of this component.
let dc_val = block[ZIGZAG[0] as usize];
let dc_diff = dc_val.saturating_sub(dc_pred[sc_idx]);
dc_pred[sc_idx] = dc_val;
let (dc_cat, dc_bits, dc_n) = encode_value(dc_diff);
let (dc_code, dc_code_len) = {
let e = dc_table.encode[dc_cat as usize];
// A zero length marks a symbol without a code in this table.
if e.1 == 0 {
return Err(DctError::CorruptEntropy);
}
e
};
writer.write_bits(dc_code, dc_code_len);
writer.write_bits(dc_bits, dc_n);
// AC: find the last non-zero coefficient in scan order; everything after
// it is covered by the end-of-block symbol.
let last_nonzero_zz = (1..64).rev().find(|&i| block[ZIGZAG[i] as usize] != 0);
let mut k = 1usize;
let mut zero_run = 0usize;
if let Some(last_pos) = last_nonzero_zz {
while k <= last_pos {
let val = block[ZIGZAG[k] as usize];
if val == 0 {
zero_run += 1;
// Sixteen zeros in a row are written as the 0xF0 symbol.
if zero_run == 16 {
let (zrl_code, zrl_len) = {
let e = ac_table.encode[0xF0];
if e.1 == 0 {
return Err(DctError::CorruptEntropy);
}
e
};
writer.write_bits(zrl_code, zrl_len);
zero_run = 0;
}
} else {
// Symbol: pending zero run in the high nibble, value bit length in the
// low nibble; the value bits follow the code.
let (cat, bits, n) = encode_value(val);
let rs = ((zero_run as u8) << 4) | cat;
let (ac_code, ac_len) = {
let e = ac_table.encode[rs as usize];
if e.1 == 0 {
return Err(DctError::CorruptEntropy);
}
e
};
writer.write_bits(ac_code, ac_len);
writer.write_bits(bits, n);
zero_run = 0;
}
k += 1;
}
}
// End-of-block (symbol 0x00) is written unless the final scan position
// (63) holds a non-zero coefficient.
let needs_eob = last_nonzero_zz.map_or(true, |p| p < 63);
if needs_eob {
let (eob_code, eob_len) = {
let e = ac_table.encode[0x00];
if e.1 == 0 {
return Err(DctError::CorruptEntropy);
}
e
};
writer.write_bits(eob_code, eob_len);
}
}
}
}
// Pad the final partial byte.
writer.flush();
// Splice: original header bytes, new entropy data, original tail.
let after_entropy = self.entropy_start + self.entropy_len;
let mut out = Vec::with_capacity(original.len());
out.extend_from_slice(&original[..self.entropy_start]);
out.extend_from_slice(&writer.out);
out.extend_from_slice(&original[after_entropy..]);
Ok(out)
}
}
/// Converts the `cat` additional bits read after a Huffman symbol back into
/// a signed coefficient value.
///
/// `bits` with its top bit (of the `cat` bits) set is a positive value and is
/// returned unchanged; otherwise the result is `bits - 2^cat + 1`.
/// Category 0 always yields 0. `cat` is expected to be at most 15 and `bits`
/// to be below `2^cat`.
fn decode_magnitude(cat: u8, bits: u16) -> i16 {
    if cat == 0 {
        return 0;
    }
    // Work in i32: for category 15 the intermediate `2^cat` is not
    // representable in i16 arithmetic.
    let raw = i32::from(bits);
    let half = 1i32 << (cat - 1);
    let value = if raw >= half {
        raw
    } else {
        raw - (1i32 << cat) + 1
    };
    // For `cat <= 15` the value lies within -32767..=32767.
    value as i16
}
// Unit tests. Fixtures are encoded with the `image` crate's JPEG encoder.
#[cfg(test)]
mod tests {
use super::*;
// Encodes a synthetic grayscale image of the given size at quality 90.
fn make_jpeg_gray(width: u32, height: u32) -> Vec<u8> {
use image::{codecs::jpeg::JpegEncoder, GrayImage, ImageEncoder};
let img = GrayImage::from_fn(width, height, |x, y| {
image::Luma([(((x * 7 + y * 13) % 200) + 28) as u8])
});
let mut buf = Vec::new();
let enc = JpegEncoder::new_with_quality(&mut buf, 90);
enc.write_image(img.as_raw(), width, height, image::ExtendedColorType::L8)
.unwrap();
buf
}
// Encodes a synthetic RGB image of the given size at quality 85.
fn make_jpeg_rgb(width: u32, height: u32) -> Vec<u8> {
use image::{codecs::jpeg::JpegEncoder, ImageEncoder, RgbImage};
let img = RgbImage::from_fn(width, height, |x, y| {
image::Rgb([
((x * 11 + y * 3) % 200 + 28) as u8,
((x * 5 + y * 17) % 200 + 28) as u8,
((x * 3 + y * 7) % 200 + 28) as u8,
])
});
let mut buf = Vec::new();
let enc = JpegEncoder::new_with_quality(&mut buf, 85);
enc.write_image(img.as_raw(), width, height, image::ExtendedColorType::Rgb8)
.unwrap();
buf
}
// Input without the SOI marker is rejected as NotJpeg.
#[test]
fn not_jpeg_returns_error() {
let result = read_coefficients(b"PNG\x00garbage");
assert!(matches!(result, Err(DctError::NotJpeg)));
}
// Empty input is rejected as NotJpeg.
#[test]
fn empty_input_returns_error() {
assert!(matches!(read_coefficients(b""), Err(DctError::NotJpeg)));
}
// SOI followed by a lone 0xFF must fail with Truncated or Missing.
#[test]
fn truncated_returns_error() {
assert!(matches!(
read_coefficients(b"\xFF\xD8\xFF"),
Err(DctError::Truncated | DctError::Missing(_))
));
}
// A hand-built stream (SOI, APP0, SOF2) must be reported as unsupported.
#[test]
fn progressive_jpeg_returns_unsupported() {
let mut data = vec![0xFF, 0xD8]; data.extend_from_slice(&[0xFF, 0xE0, 0x00, 0x10]);
data.extend_from_slice(&[
0x4A, 0x46, 0x49, 0x46, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00,
]);
data.extend_from_slice(&[0xFF, 0xC2, 0x00, 0x0B]);
data.extend_from_slice(&[0x08, 0x00, 0x10, 0x00, 0x10, 0x01, 0x01, 0x11, 0x00]);
let result = read_coefficients(&data);
assert!(matches!(result, Err(DctError::Unsupported(_))));
}
// Dropping one block makes the coefficients incompatible with the JPEG.
#[test]
fn incompatible_block_count_returns_error() {
let jpeg = make_jpeg_gray(16, 16);
let mut coeffs = read_coefficients(&jpeg).unwrap();
coeffs.components[0].blocks.pop();
let result = write_coefficients(&jpeg, &coeffs);
assert!(matches!(result, Err(DctError::Incompatible(_))));
}
// Decoding and re-encoding unchanged coefficients reproduces the file bytes.
#[test]
fn roundtrip_identity_gray() {
let jpeg = make_jpeg_gray(32, 32);
let coeffs = read_coefficients(&jpeg).unwrap();
let reencoded = write_coefficients(&jpeg, &coeffs).unwrap();
assert_eq!(jpeg, reencoded, "roundtrip changed the JPEG bytes");
}
// Same byte-exact round trip for a three-component image.
#[test]
fn roundtrip_identity_rgb() {
let jpeg = make_jpeg_rgb(32, 32);
let coeffs = read_coefficients(&jpeg).unwrap();
let reencoded = write_coefficients(&jpeg, &coeffs).unwrap();
assert_eq!(jpeg, reencoded, "roundtrip changed the JPEG bytes");
}
// Same byte-exact round trip for a non-square image.
#[test]
fn roundtrip_identity_non_square() {
let jpeg = make_jpeg_rgb(48, 16);
let coeffs = read_coefficients(&jpeg).unwrap();
let reencoded = write_coefficients(&jpeg, &coeffs).unwrap();
assert_eq!(jpeg, reencoded);
}
// Flipping the lowest bit of every eligible AC coefficient must survive
// write followed by read.
#[test]
fn lsb_modification_survives_roundtrip() {
let jpeg = make_jpeg_gray(32, 32);
let mut coeffs = read_coefficients(&jpeg).unwrap();
let mut modified_count = 0usize;
for block in &mut coeffs.components[0].blocks {
for coeff in block[1..].iter_mut() {
if coeff.abs() >= 2 {
*coeff ^= 1;
modified_count += 1;
}
}
}
assert!(
modified_count > 0,
"test image had no eligible coefficients"
);
let modified_jpeg = write_coefficients(&jpeg, &coeffs).unwrap();
let coeffs2 = read_coefficients(&modified_jpeg).unwrap();
assert_eq!(coeffs.components[0].blocks, coeffs2.components[0].blocks);
}
// A 16x16 grayscale image consists of four 8x8 blocks.
#[test]
fn block_count_gray_16x16() {
let jpeg = make_jpeg_gray(16, 16);
let counts = block_count(&jpeg).unwrap();
assert_eq!(counts, vec![4]);
}
// An RGB image reports three components with a non-zero block total.
#[test]
fn block_count_rgb_32x32() {
let jpeg = make_jpeg_rgb(32, 32);
let counts = block_count(&jpeg).unwrap();
assert_eq!(counts.len(), 3);
let total: usize = counts.iter().sum();
assert!(total > 0);
}
// Spot checks of the magnitude category, including the cap at 15.
#[test]
fn category_values() {
assert_eq!(category(0), 0);
assert_eq!(category(1), 1);
assert_eq!(category(-1), 1);
assert_eq!(category(2), 2);
assert_eq!(category(3), 2);
assert_eq!(category(4), 3);
assert_eq!(category(127), 7);
assert_eq!(category(-128), 8);
assert_eq!(category(1023), 10);
assert_eq!(category(i16::MAX), 15); }
// Output written after a modification still starts with SOI and ends with EOI.
#[test]
fn output_is_valid_jpeg() {
let jpeg = make_jpeg_rgb(24, 24);
let mut coeffs = read_coefficients(&jpeg).unwrap();
if let Some(block) = coeffs.components[0].blocks.first_mut() {
block[1] |= 1;
}
let out = write_coefficients(&jpeg, &coeffs).unwrap();
assert_eq!(&out[..2], &[0xFF, 0xD8], "missing SOI");
assert_eq!(&out[out.len() - 2..], &[0xFF, 0xD9], "missing EOI");
}
// inspect reports the encoded dimensions; 32x16 grayscale is eight blocks.
#[test]
fn inspect_gray_returns_correct_dimensions() {
let jpeg = make_jpeg_gray(32, 16);
let info = inspect(&jpeg).unwrap();
assert_eq!(info.width, 32);
assert_eq!(info.height, 16);
assert_eq!(info.components.len(), 1);
assert_eq!(info.components[0].block_count, 8); }
// inspect reports three components for an RGB image.
#[test]
fn inspect_rgb_returns_three_components() {
let jpeg = make_jpeg_rgb(32, 32);
let info = inspect(&jpeg).unwrap();
assert_eq!(info.width, 32);
assert_eq!(info.height, 32);
assert_eq!(info.components.len(), 3);
let total: usize = info.components.iter().map(|c| c.block_count).sum();
assert!(total > 0);
}
// inspect and block_count agree on the per-component block counts.
#[test]
fn inspect_matches_block_count() {
let jpeg = make_jpeg_rgb(48, 32);
let info = inspect(&jpeg).unwrap();
let counts = block_count(&jpeg).unwrap();
let info_counts: Vec<usize> = info.components.iter().map(|c| c.block_count).collect();
assert_eq!(info_counts, counts);
}
// The synthetic RGB fixture has at least one eligible AC coefficient.
#[test]
fn eligible_ac_count_is_positive() {
let jpeg = make_jpeg_rgb(32, 32);
let n = eligible_ac_count(&jpeg).unwrap();
assert!(n > 0, "natural image should have eligible AC coefficients");
}
// The method and the free function return the same count.
#[test]
fn eligible_ac_count_method_matches_free_fn() {
let jpeg = make_jpeg_gray(32, 32);
let coeffs = read_coefficients(&jpeg).unwrap();
let via_method = coeffs.eligible_ac_count();
let via_fn = eligible_ac_count(&jpeg).unwrap();
assert_eq!(via_method, via_fn);
}
// The eligible count cannot exceed 63 AC coefficients per block.
#[test]
fn eligible_ac_count_leq_total_ac_count() {
let jpeg = make_jpeg_rgb(32, 32);
let coeffs = read_coefficients(&jpeg).unwrap();
let eligible = coeffs.eligible_ac_count();
let total_ac: usize = coeffs
.components
.iter()
.flat_map(|c| c.blocks.iter())
.map(|_| 63) .sum();
assert!(eligible <= total_ac);
}
// Bit flips across all components of a larger image survive write followed
// by read.
#[test]
fn lut_decode_matches_modification_roundtrip() {
let jpeg = make_jpeg_rgb(64, 64);
let mut coeffs = read_coefficients(&jpeg).unwrap();
let mut flipped = 0usize;
for comp in &mut coeffs.components {
for block in &mut comp.blocks {
for coeff in block[1..].iter_mut() {
if coeff.abs() >= 2 {
*coeff ^= 1;
flipped += 1;
}
}
}
}
assert!(flipped > 0);
let modified = write_coefficients(&jpeg, &coeffs).unwrap();
let coeffs2 = read_coefficients(&modified).unwrap();
assert_eq!(coeffs.components.len(), coeffs2.components.len());
for (c1, c2) in coeffs.components.iter().zip(coeffs2.components.iter()) {
assert_eq!(c1.blocks, c2.blocks);
}
}
}