use oxideav_core::frame::VideoPlane;
use oxideav_core::Decoder;
use oxideav_core::{
CodecId, CodecParameters, Error, Frame, Packet, PixelFormat, Result, TimeBase, VideoFrame,
};
use crate::jpeg::dct::idct8x8;
use crate::jpeg::huffman::{parse_dht, HuffTable};
use crate::jpeg::markers::{self, *};
use crate::jpeg::parser::{parse_dri, parse_sof, parse_sos, MarkerWalker, SofInfo, SosInfo};
use crate::jpeg::quant::{parse_dqt, QuantTable};
use crate::jpeg::zigzag::ZIGZAG;
/// Creates a boxed MJPEG decoder for the given codec parameters.
///
/// Only `params.codec_id` is consulted; each packet handed to the decoder is
/// later decoded as one complete JPEG image.
pub fn make_decoder(params: &CodecParameters) -> Result<Box<dyn Decoder>> {
    let decoder = MjpegDecoder {
        codec_id: params.codec_id.clone(),
        pending: None,
        eof: false,
    };
    Ok(Box::new(decoder))
}
struct MjpegDecoder {
codec_id: CodecId,
pending: Option<Packet>,
eof: bool,
}
impl Decoder for MjpegDecoder {
fn codec_id(&self) -> &CodecId {
&self.codec_id
}
fn send_packet(&mut self, packet: &Packet) -> Result<()> {
if self.pending.is_some() {
return Err(Error::other(
"MJPEG decoder: receive_frame must be called before sending another packet",
));
}
self.pending = Some(packet.clone());
Ok(())
}
fn receive_frame(&mut self) -> Result<Frame> {
let Some(pkt) = self.pending.take() else {
return if self.eof {
Err(Error::Eof)
} else {
Err(Error::NeedMore)
};
};
let vf = decode_jpeg(&pkt.data, pkt.pts, pkt.time_base)?;
Ok(Frame::Video(vf))
}
fn flush(&mut self) -> Result<()> {
self.eof = true;
Ok(())
}
}
/// Tables and frame information collected while walking one image's markers.
struct JpegState {
    /// Quantization tables by DQT table id (0..=3).
    quant: [Option<QuantTable>; 4],
    /// DC Huffman tables by DHT table id (0..=3).
    dc_huff: [Option<HuffTable>; 4],
    /// AC Huffman tables by DHT table id (0..=3).
    ac_huff: [Option<HuffTable>; 4],
    /// MCUs between RST markers as set by DRI; 0 disables restart handling.
    restart_interval: u16,
    /// Most recent frame header (SOF0/1/2/3).
    sof: Option<SofInfo>,
    /// Set by SOF2: scans are accumulated and rendered at EOI.
    progressive: bool,
    /// Set once a sequential image needs multi-scan coefficient accumulation.
    seq_accum: bool,
    /// Set by SOF3: the first scan is decoded losslessly and returned.
    lossless: bool,
    /// Transform byte (offset 11) of an Adobe APP14 segment, if one was seen.
    adobe_transform: Option<u8>,
}

impl JpegState {
    /// Returns an empty state: no tables, no frame header, restarts disabled.
    fn new() -> Self {
        JpegState {
            quant: [None, None, None, None],
            dc_huff: [None, None, None, None],
            ac_huff: [None, None, None, None],
            sof: None,
            restart_interval: 0,
            progressive: false,
            seq_accum: false,
            lossless: false,
            adobe_transform: None,
        }
    }
}
fn decode_jpeg(data: &[u8], pts: Option<i64>, time_base: TimeBase) -> Result<VideoFrame> {
if data.len() < 2 || data[0] != 0xFF || data[1] != markers::SOI {
return Err(Error::invalid("JPEG: missing SOI"));
}
let mut walker = MarkerWalker::new(&data[2..]);
let mut state = JpegState::new();
let mut coef_buf: Vec<Vec<[i32; 64]>> = Vec::new();
loop {
let Some(marker) = walker.next_marker()? else {
return Err(Error::invalid("JPEG: unexpected EOF before EOI"));
};
match marker {
EOI => {
if state.progressive || state.seq_accum {
return render_from_coefs(&state, &coef_buf, pts, time_base);
}
return Err(Error::invalid("JPEG: EOI before SOS"));
}
SOI => continue,
m if markers::is_rst(m) => continue,
DQT => {
let p = walker.read_segment_payload()?;
parse_dqt(p, &mut state.quant)?;
}
DHT => {
let p = walker.read_segment_payload()?;
parse_dht(p, &mut state.dc_huff, &mut state.ac_huff)?;
}
DRI => {
let p = walker.read_segment_payload()?;
state.restart_interval = parse_dri(p)?;
}
SOF0 | SOF1 => {
let p = walker.read_segment_payload()?;
state.sof = Some(parse_sof(p)?);
}
SOF2 => {
let p = walker.read_segment_payload()?;
let sof = parse_sof(p)?;
if sof.components.len() > 3 {
return Err(Error::unsupported(
"progressive JPEG: 4+ component scans not supported",
));
}
if sof.precision != 8 {
return Err(Error::unsupported(format!(
"progressive JPEG: precision {} (only 8 is supported)",
sof.precision
)));
}
coef_buf = init_coef_buffers(&sof)?;
state.sof = Some(sof);
state.progressive = true;
}
SOF3 => {
let p = walker.read_segment_payload()?;
let sof = parse_sof(p)?;
if !(2..=16).contains(&sof.precision) {
return Err(Error::unsupported(format!(
"lossless JPEG: precision {} out of range 2..=16",
sof.precision
)));
}
if sof.components.len() != 1 {
return Err(Error::unsupported(
"lossless JPEG: only single-component (grayscale) scans are supported",
));
}
state.sof = Some(sof);
state.lossless = true;
}
0xC5..=0xC7 | 0xC9..=0xCB | 0xCD..=0xCF => {
let _ = walker.read_segment_payload();
return Err(Error::unsupported(
"JPEG: hierarchical and arithmetic-coded variants are not supported",
));
}
SOS => {
let p = walker.read_segment_payload()?;
let sos = parse_sos(p)?;
let scan = walker.read_scan_data()?;
if state.lossless {
return decode_lossless_scan(&state, &sos, scan, pts, time_base);
}
if state.progressive {
decode_progressive_scan(&state, &sos, scan, &mut coef_buf)?;
} else {
let sof = state
.sof
.as_ref()
.ok_or_else(|| Error::invalid("SOS before SOF"))?;
let fully_interleaved = sos.components.len() == sof.components.len();
let fast_path_ok = fully_interleaved
&& !state.seq_accum
&& sof.components.len() <= 3
&& sof.precision == 8;
if fast_path_ok {
return decode_scan(&state, &sos, scan, pts, time_base);
}
if !state.seq_accum {
coef_buf = init_coef_buffers(sof)?;
state.seq_accum = true;
}
decode_sequential_scan_accum(&state, &sos, scan, &mut coef_buf)?;
}
}
COM => {
let _ = walker.read_segment_payload()?;
}
markers::APP14 => {
let p = walker.read_segment_payload()?;
if p.len() >= 12 && &p[0..5] == b"Adobe" {
state.adobe_transform = Some(p[11]);
}
}
m if markers::is_app(m) => {
let _ = walker.read_segment_payload()?;
}
_ => {
let _ = walker.read_segment_payload();
}
}
}
}
fn init_coef_buffers(sof: &SofInfo) -> Result<Vec<Vec<[i32; 64]>>> {
if sof.precision != 8 && sof.precision != 12 {
return Err(Error::unsupported(format!(
"coef accumulator: precision {} (only 8 and 12 are supported)",
sof.precision
)));
}
if sof.components.is_empty() {
return Err(Error::invalid("SOF: no components"));
}
if sof.components.len() > 4 {
return Err(Error::unsupported("coef accumulator: >4 components"));
}
let h_max = sof.components.iter().map(|c| c.h_factor).max().unwrap_or(1);
let v_max = sof.components.iter().map(|c| c.v_factor).max().unwrap_or(1);
if h_max == 0 || v_max == 0 {
return Err(Error::invalid("SOF: sampling factor = 0"));
}
let width = sof.width as usize;
let height = sof.height as usize;
let mcu_w_px = 8 * h_max as usize;
let mcu_h_px = 8 * v_max as usize;
let mcus_x = width.div_ceil(mcu_w_px);
let mcus_y = height.div_ceil(mcu_h_px);
let mut out = Vec::with_capacity(sof.components.len());
for c in &sof.components {
let blocks_x = mcus_x * c.h_factor as usize;
let blocks_y = mcus_y * c.v_factor as usize;
out.push(vec![[0i32; 64]; blocks_x * blocks_y]);
}
Ok(out)
}
/// MSB-first bit reader over entropy-coded scan data.
///
/// Byte stuffing (`FF 00`) is removed transparently. When a marker is reached
/// no data byte is produced. An RST marker is consumed and recorded in
/// `saw_rst`; any other marker is left in place. `fill` then pads the bit
/// buffer with zero bits.
struct BitReader<'a> {
    buf: &'a [u8],
    /// Index of the next unread byte in `buf`.
    pos: usize,
    /// Bit accumulator, left-aligned (the next bit to read is bit 31).
    bits: u32,
    /// Number of valid bits currently held in `bits`.
    nbits: u32,
    /// Set to the RST marker byte when one is encountered; cleared by
    /// `reset_at_restart`.
    pub saw_rst: Option<u8>,
}

impl<'a> BitReader<'a> {
    fn new(buf: &'a [u8]) -> Self {
        BitReader {
            buf,
            pos: 0,
            bits: 0,
            nbits: 0,
            saw_rst: None,
        }
    }

    /// Returns the next de-stuffed data byte.
    ///
    /// `Ok(None)` means no data byte is available at this point. That happens
    /// at end of buffer, at an RST marker (consumed and recorded in `saw_rst`),
    /// or at another marker (not consumed, so repeated calls keep returning
    /// `None`).
    ///
    /// # Errors
    /// `Error::invalid` if the buffer ends with `0xFF` bytes and no code byte.
    fn next_byte_with_stuff(&mut self) -> Result<Option<u8>> {
        let Some(&byte) = self.buf.get(self.pos) else {
            return Ok(None);
        };
        self.pos += 1;
        if byte != 0xFF {
            return Ok(Some(byte));
        }
        // A run of 0xFF fill bytes may precede the code byte.
        while self.buf.get(self.pos) == Some(&0xFF) {
            self.pos += 1;
        }
        let Some(&code) = self.buf.get(self.pos) else {
            return Err(Error::invalid("scan: 0xFF at end without followup"));
        };
        self.pos += 1;
        match code {
            0x00 => Ok(Some(0xFF)),
            rst if markers::is_rst(rst) => {
                self.saw_rst = Some(rst);
                Ok(None)
            }
            _ => {
                // Step back onto the 0xFF so the marker stays unconsumed.
                self.pos -= 2;
                Ok(None)
            }
        }
    }

    /// Ensures at least `needed` bits are buffered. Once no data byte is
    /// available, the buffer is padded with zero bits up to `needed`.
    /// The shift `24 - nbits` requires `needed <= 25`.
    fn fill(&mut self, needed: u32) -> Result<()> {
        while self.nbits < needed {
            let Some(byte) = self.next_byte_with_stuff()? else {
                self.nbits = needed;
                break;
            };
            self.bits |= u32::from(byte) << (24 - self.nbits);
            self.nbits += 8;
        }
        Ok(())
    }

    /// Reads the next `n` bits MSB-first. `n` must be non-zero (a shift by
    /// 32 would overflow) and at most 25.
    fn get_bits(&mut self, n: u32) -> Result<u32> {
        self.fill(n)?;
        let value = self.bits >> (32 - n);
        self.bits <<= n;
        self.nbits -= n;
        Ok(value)
    }

    /// Discards buffered bits and forgets any recorded RST marker.
    fn reset_at_restart(&mut self) {
        self.saw_rst = None;
        self.nbits = 0;
        self.bits = 0;
    }
}
fn decode_huff(br: &mut BitReader<'_>, t: &HuffTable) -> Result<u8> {
let mut code: i32 = 0;
for l in 0..16 {
let bit = br.get_bits(1)? as i32;
code = (code << 1) | bit;
if code <= t.max_code[l] {
let idx = (t.val_offset[l] + code) as usize;
if idx >= t.values.len() {
return Err(Error::invalid("huffman: value index OOB"));
}
return Ok(t.values[idx]);
}
}
Err(Error::invalid("huffman: no matching code (length > 16)"))
}
/// JPEG `EXTEND`: maps the `size`-bit magnitude category value `value` to
/// its signed coefficient.
///
/// Values whose top bit is set (`value >= 2^(size-1)`) are positive as-is.
/// Smaller values encode negatives, `value - (2^size - 1)`. A size of 0
/// always yields 0.
fn extend(value: i32, size: u32) -> i32 {
    match size {
        0 => 0,
        _ if value >= 1 << (size - 1) => value,
        _ => value - ((1 << size) - 1),
    }
}
fn decode_scan(
state: &JpegState,
sos: &SosInfo,
scan: &[u8],
pts: Option<i64>,
_time_base: TimeBase,
) -> Result<VideoFrame> {
let sof = state
.sof
.as_ref()
.ok_or_else(|| Error::invalid("SOS before SOF"))?;
if sof.precision != 8 {
return Err(Error::unsupported("precision != 8"));
}
if sof.components.is_empty() {
return Err(Error::invalid("SOF: no components"));
}
if sof.components.len() > 3 {
return Err(Error::unsupported("4+ components"));
}
if sos.components.len() != sof.components.len() {
return Err(Error::unsupported("non-interleaved scan"));
}
let n_comp = sof.components.len();
let grayscale = n_comp == 1;
let h_max = sof.components.iter().map(|c| c.h_factor).max().unwrap_or(1);
let v_max = sof.components.iter().map(|c| c.v_factor).max().unwrap_or(1);
if h_max == 0 || v_max == 0 {
return Err(Error::invalid("SOF: sampling factor = 0"));
}
let pix_fmt = if grayscale {
PixelFormat::Gray8
} else if n_comp == 3 {
let y = sof.components[0];
let cb = sof.components[1];
let cr = sof.components[2];
if cb.h_factor != cr.h_factor || cb.v_factor != cr.v_factor {
return Err(Error::unsupported(
"chroma components have different sampling factors",
));
}
if cb.h_factor != 1 || cb.v_factor != 1 {
return Err(Error::unsupported(
"chroma components must have factor 1 (luma carries the oversampling)",
));
}
match (y.h_factor, y.v_factor) {
(1, 1) => PixelFormat::Yuv444P,
(2, 1) => PixelFormat::Yuv422P,
(2, 2) => PixelFormat::Yuv420P,
_ => {
return Err(Error::unsupported(format!(
"luma sampling {}x{}",
y.h_factor, y.v_factor
)))
}
}
} else {
return Err(Error::unsupported("2-component JPEG"));
};
let width = sof.width as usize;
let height = sof.height as usize;
let mcu_w_px = 8 * h_max as usize;
let mcu_h_px = 8 * v_max as usize;
let mcus_x = width.div_ceil(mcu_w_px);
let mcus_y = height.div_ceil(mcu_h_px);
let mut comp_buf: Vec<Vec<u8>> = Vec::with_capacity(n_comp);
let mut comp_stride: Vec<usize> = Vec::with_capacity(n_comp);
let mut comp_w_full: Vec<usize> = Vec::with_capacity(n_comp);
let mut comp_h_full: Vec<usize> = Vec::with_capacity(n_comp);
for c in &sof.components {
let w_full = mcus_x * 8 * c.h_factor as usize;
let h_full = mcus_y * 8 * c.v_factor as usize;
comp_buf.push(vec![0u8; w_full * h_full]);
comp_stride.push(w_full);
comp_w_full.push(w_full);
comp_h_full.push(h_full);
}
let sos_map: Vec<usize> = sos
.components
.iter()
.map(|sc| {
sof.components
.iter()
.position(|fc| fc.id == sc.id)
.ok_or_else(|| Error::invalid("SOS: component id not in SOF"))
})
.collect::<Result<Vec<_>>>()?;
let dc_tables: Vec<&HuffTable> = sos
.components
.iter()
.map(|sc| {
state.dc_huff[sc.dc_table as usize]
.as_ref()
.ok_or_else(|| Error::invalid("SOS: DC Huffman table missing"))
})
.collect::<Result<Vec<_>>>()?;
let ac_tables: Vec<&HuffTable> = sos
.components
.iter()
.map(|sc| {
state.ac_huff[sc.ac_table as usize]
.as_ref()
.ok_or_else(|| Error::invalid("SOS: AC Huffman table missing"))
})
.collect::<Result<Vec<_>>>()?;
let quant_tables: Vec<&QuantTable> = sof
.components
.iter()
.map(|c| {
state.quant[c.qt_id as usize]
.as_ref()
.ok_or_else(|| Error::invalid("quant table missing for component"))
})
.collect::<Result<Vec<_>>>()?;
let mut br = BitReader::new(scan);
let mut prev_dc = vec![0i32; n_comp];
let mut mcus_since_restart: u32 = 0;
let mut expected_rst: u8 = RST0;
for my in 0..mcus_y {
for mx in 0..mcus_x {
if state.restart_interval != 0
&& mcus_since_restart != 0
&& mcus_since_restart % state.restart_interval as u32 == 0
{
br.bits = 0;
br.nbits = 0;
while br.saw_rst.is_none() {
let prev = br.pos;
match br.next_byte_with_stuff()? {
Some(_) => {
}
None => {
if br.saw_rst.is_none() {
if prev == br.pos {
break;
}
}
break;
}
}
}
if let Some(m) = br.saw_rst {
if m != expected_rst {
}
expected_rst = if expected_rst == RST7 {
RST0
} else {
expected_rst + 1
};
for p in prev_dc.iter_mut() {
*p = 0;
}
br.reset_at_restart();
}
}
for (sidx, sof_idx) in sos_map.iter().enumerate() {
let c = sof.components[*sof_idx];
for by in 0..c.v_factor as usize {
for bx in 0..c.h_factor as usize {
let mut block = [0i32; 64];
decode_block(
&mut br,
dc_tables[sidx],
ac_tables[sidx],
&mut prev_dc[*sof_idx],
&mut block,
)?;
let qt = quant_tables[*sof_idx];
let mut fblock = [0.0f32; 64];
for k in 0..64 {
fblock[k] = (block[k] * qt.values[k] as i32) as f32;
}
idct8x8(&mut fblock);
let dst_x0 = mx * 8 * c.h_factor as usize + bx * 8;
let dst_y0 = my * 8 * c.v_factor as usize + by * 8;
let stride = comp_stride[*sof_idx];
let buf = &mut comp_buf[*sof_idx];
for j in 0..8 {
for i in 0..8 {
let v = fblock[j * 8 + i] + 128.0;
let px = if v <= 0.0 {
0
} else if v >= 255.0 {
255
} else {
v.round() as u8
};
buf[(dst_y0 + j) * stride + dst_x0 + i] = px;
}
}
}
}
}
mcus_since_restart += 1;
}
}
let out_format = pix_fmt;
let mut planes: Vec<VideoPlane> = Vec::new();
match out_format {
PixelFormat::Gray8 => {
let stride = width;
let mut data = vec![0u8; stride * height];
let src_stride = comp_stride[0];
for y in 0..height {
data[y * stride..y * stride + width]
.copy_from_slice(&comp_buf[0][y * src_stride..y * src_stride + width]);
}
planes.push(VideoPlane { stride, data });
}
PixelFormat::Yuv444P | PixelFormat::Yuv422P | PixelFormat::Yuv420P => {
let (c_w, c_h) = match out_format {
PixelFormat::Yuv444P => (width, height),
PixelFormat::Yuv422P => (width.div_ceil(2), height),
PixelFormat::Yuv420P => (width.div_ceil(2), height.div_ceil(2)),
_ => unreachable!(),
};
let y_stride = width;
let mut y_data = vec![0u8; y_stride * height];
let src_stride_y = comp_stride[0];
for y in 0..height {
y_data[y * y_stride..y * y_stride + width]
.copy_from_slice(&comp_buf[0][y * src_stride_y..y * src_stride_y + width]);
}
planes.push(VideoPlane {
stride: y_stride,
data: y_data,
});
for ci in [1usize, 2] {
let src_stride = comp_stride[ci];
let stride = c_w;
let mut data = vec![0u8; stride * c_h];
for y in 0..c_h {
data[y * stride..y * stride + c_w]
.copy_from_slice(&comp_buf[ci][y * src_stride..y * src_stride + c_w]);
}
planes.push(VideoPlane { stride, data });
}
}
_ => unreachable!(),
}
Ok(VideoFrame { pts, planes })
}
/// Decodes one sequential-mode 8x8 block from `br`.
///
/// The DC difference is added (wrapping) to `*prev_dc`, and the result is
/// stored at index 0. AC coefficients are written at `ZIGZAG[k]` for zig-zag
/// position `k`. Coefficients that are not coded are left untouched; the
/// callers in this file pass zeroed blocks.
///
/// # Errors
/// `Error::invalid` if an AC run moves past coefficient 63; Huffman and
/// bit-reader errors are propagated.
fn decode_block(
    br: &mut BitReader<'_>,
    dc: &HuffTable,
    ac: &HuffTable,
    prev_dc: &mut i32,
    out_natural: &mut [i32; 64],
) -> Result<()> {
    let dc_size = u32::from(decode_huff(br, dc)?);
    let diff = match dc_size {
        0 => 0,
        n => extend(br.get_bits(n)? as i32, n),
    };
    *prev_dc = prev_dc.wrapping_add(diff);
    out_natural[0] = *prev_dc;

    let mut k = 1usize;
    while k < 64 {
        let symbol = decode_huff(br, ac)?;
        let run = (symbol >> 4) as usize;
        let size = u32::from(symbol & 0x0F);
        match (run, size) {
            // ZRL: sixteen zero coefficients.
            (15, 0) => k += 16,
            // EOB: the rest of the block is zero.
            (_, 0) => break,
            _ => {
                k += run;
                if k >= 64 {
                    return Err(Error::invalid("JPEG AC: run out of block"));
                }
                let raw = br.get_bits(size)? as i32;
                out_natural[ZIGZAG[k]] = extend(raw, size);
                k += 1;
            }
        }
    }
    Ok(())
}
fn decode_sequential_scan_accum(
state: &JpegState,
sos: &SosInfo,
scan: &[u8],
coefs: &mut [Vec<[i32; 64]>],
) -> Result<()> {
let sof = state
.sof
.as_ref()
.ok_or_else(|| Error::invalid("SOS before SOF"))?;
if sof.precision != 8 && sof.precision != 12 {
return Err(Error::unsupported(format!(
"scan: unsupported precision {}",
sof.precision
)));
}
let h_max = sof.components.iter().map(|c| c.h_factor).max().unwrap_or(1) as usize;
let v_max = sof.components.iter().map(|c| c.v_factor).max().unwrap_or(1) as usize;
if h_max == 0 || v_max == 0 {
return Err(Error::invalid("SOF: sampling factor = 0"));
}
let width = sof.width as usize;
let height = sof.height as usize;
let mcus_x = width.div_ceil(8 * h_max);
let mcus_y = height.div_ceil(8 * v_max);
let sos_map: Vec<usize> = sos
.components
.iter()
.map(|sc| {
sof.components
.iter()
.position(|fc| fc.id == sc.id)
.ok_or_else(|| Error::invalid("SOS: component id not in SOF"))
})
.collect::<Result<Vec<_>>>()?;
let dc_tables: Vec<&HuffTable> = sos
.components
.iter()
.map(|sc| {
state.dc_huff[sc.dc_table as usize]
.as_ref()
.ok_or_else(|| Error::invalid("SOS: DC Huffman table missing"))
})
.collect::<Result<Vec<_>>>()?;
let ac_tables: Vec<&HuffTable> = sos
.components
.iter()
.map(|sc| {
state.ac_huff[sc.ac_table as usize]
.as_ref()
.ok_or_else(|| Error::invalid("SOS: AC Huffman table missing"))
})
.collect::<Result<Vec<_>>>()?;
let interleaved = sos.components.len() == sof.components.len();
let (scan_mcus_x, scan_mcus_y) = if interleaved {
(mcus_x, mcus_y)
} else {
let sof_idx = sos_map[0];
let c = sof.components[sof_idx];
(mcus_x * c.h_factor as usize, mcus_y * c.v_factor as usize)
};
let mut br = BitReader::new(scan);
let mut prev_dc = vec![0i32; sos.components.len()];
let mut mcus_since_restart: u32 = 0;
let mut expected_rst: u8 = RST0;
for my in 0..scan_mcus_y {
for mx in 0..scan_mcus_x {
if state.restart_interval != 0
&& mcus_since_restart != 0
&& mcus_since_restart % state.restart_interval as u32 == 0
{
br.bits = 0;
br.nbits = 0;
while br.saw_rst.is_none() {
let prev = br.pos;
match br.next_byte_with_stuff()? {
Some(_) => {}
None => {
if prev == br.pos {
break;
}
break;
}
}
}
if br.saw_rst.is_some() {
expected_rst = if expected_rst == RST7 {
RST0
} else {
expected_rst + 1
};
for p in prev_dc.iter_mut() {
*p = 0;
}
br.reset_at_restart();
}
}
if interleaved {
for (sidx, &sof_idx) in sos_map.iter().enumerate() {
let c = sof.components[sof_idx];
let blocks_x = mcus_x * c.h_factor as usize;
for by in 0..c.v_factor as usize {
for bx in 0..c.h_factor as usize {
let bidx_x = mx * c.h_factor as usize + bx;
let bidx_y = my * c.v_factor as usize + by;
let bi = bidx_y * blocks_x + bidx_x;
decode_block(
&mut br,
dc_tables[sidx],
ac_tables[sidx],
&mut prev_dc[sidx],
&mut coefs[sof_idx][bi],
)?;
}
}
}
} else {
let sof_idx = sos_map[0];
let c = sof.components[sof_idx];
let blocks_x = mcus_x * c.h_factor as usize;
let bi = my * blocks_x + mx;
decode_block(
&mut br,
dc_tables[0],
ac_tables[0],
&mut prev_dc[0],
&mut coefs[sof_idx][bi],
)?;
}
mcus_since_restart += 1;
}
}
Ok(())
}
fn decode_progressive_scan(
state: &JpegState,
sos: &SosInfo,
scan: &[u8],
coefs: &mut [Vec<[i32; 64]>],
) -> Result<()> {
let sof = state
.sof
.as_ref()
.ok_or_else(|| Error::invalid("SOS before SOF"))?;
if sos.ss > 63 || sos.se > 63 || sos.ss > sos.se {
return Err(Error::invalid("progressive: invalid Ss/Se"));
}
let is_dc_scan = sos.ss == 0;
if is_dc_scan && sos.se != 0 {
return Err(Error::invalid("progressive: DC scan must have Se=0"));
}
if !is_dc_scan && sos.components.len() != 1 {
return Err(Error::invalid(
"progressive: AC scans must be non-interleaved",
));
}
if sos.ah > 13 || sos.al > 13 {
return Err(Error::invalid("progressive: Ah/Al out of range"));
}
let h_max = sof.components.iter().map(|c| c.h_factor).max().unwrap_or(1) as usize;
let v_max = sof.components.iter().map(|c| c.v_factor).max().unwrap_or(1) as usize;
let width = sof.width as usize;
let height = sof.height as usize;
let mcus_x = width.div_ceil(8 * h_max);
let mcus_y = height.div_ceil(8 * v_max);
let sos_map: Vec<usize> = sos
.components
.iter()
.map(|sc| {
sof.components
.iter()
.position(|fc| fc.id == sc.id)
.ok_or_else(|| Error::invalid("progressive SOS: unknown component"))
})
.collect::<Result<Vec<_>>>()?;
let dc_tables: Vec<Option<&HuffTable>> = sos
.components
.iter()
.map(|sc| {
if is_dc_scan {
state.dc_huff[sc.dc_table as usize].as_ref()
} else {
None
}
})
.collect();
let ac_tables: Vec<Option<&HuffTable>> = sos
.components
.iter()
.map(|sc| {
if is_dc_scan {
None
} else {
state.ac_huff[sc.ac_table as usize].as_ref()
}
})
.collect();
if is_dc_scan && dc_tables.iter().any(|t| t.is_none()) {
return Err(Error::invalid("progressive DC: Huffman table missing"));
}
if !is_dc_scan && ac_tables[0].is_none() {
return Err(Error::invalid("progressive AC: Huffman table missing"));
}
let mut br = BitReader::new(scan);
let mut prev_dc = vec![0i32; sos.components.len()];
let mut eob_run: u32 = 0;
let mut mcus_since_restart: u32 = 0;
let mut expected_rst: u8 = RST0;
let (scan_mcus_x, scan_mcus_y) = if is_dc_scan && sos.components.len() > 1 {
(mcus_x, mcus_y)
} else {
let ci = sos_map[0];
let c = sof.components[ci];
if sos.components.len() == 1 && sof.components.len() == 1 {
(mcus_x * c.h_factor as usize, mcus_y * c.v_factor as usize)
} else {
(mcus_x * c.h_factor as usize, mcus_y * c.v_factor as usize)
}
};
for my in 0..scan_mcus_y {
for mx in 0..scan_mcus_x {
if state.restart_interval != 0
&& mcus_since_restart != 0
&& mcus_since_restart % state.restart_interval as u32 == 0
{
br.bits = 0;
br.nbits = 0;
while br.saw_rst.is_none() {
let prev = br.pos;
match br.next_byte_with_stuff()? {
Some(_) => {}
None => {
if prev == br.pos {
break;
}
break;
}
}
}
if br.saw_rst.is_some() {
expected_rst = if expected_rst == RST7 {
RST0
} else {
expected_rst + 1
};
for p in prev_dc.iter_mut() {
*p = 0;
}
eob_run = 0;
br.reset_at_restart();
}
}
if is_dc_scan {
if sos.components.len() > 1 {
for (sidx, &sof_idx) in sos_map.iter().enumerate() {
let c = sof.components[sof_idx];
for by in 0..c.v_factor as usize {
for bx in 0..c.h_factor as usize {
let blocks_x = mcus_x * c.h_factor as usize;
let bidx_x = mx * c.h_factor as usize + bx;
let bidx_y = my * c.v_factor as usize + by;
let bi = bidx_y * blocks_x + bidx_x;
prog_decode_dc(
&mut br,
dc_tables[sidx].unwrap(),
&mut prev_dc[sidx],
&mut coefs[sof_idx][bi],
sos.ah,
sos.al,
)?;
}
}
}
} else {
let sof_idx = sos_map[0];
let c = sof.components[sof_idx];
let blocks_x = mcus_x * c.h_factor as usize;
let bi = my * blocks_x + mx;
prog_decode_dc(
&mut br,
dc_tables[0].unwrap(),
&mut prev_dc[0],
&mut coefs[sof_idx][bi],
sos.ah,
sos.al,
)?;
}
} else {
let sof_idx = sos_map[0];
let c = sof.components[sof_idx];
let blocks_x = mcus_x * c.h_factor as usize;
let bi = my * blocks_x + mx;
let ac = ac_tables[0].unwrap();
if sos.ah == 0 {
prog_decode_ac_first(
&mut br,
ac,
&mut coefs[sof_idx][bi],
sos.ss as usize,
sos.se as usize,
sos.al,
&mut eob_run,
)?;
} else {
prog_decode_ac_refine(
&mut br,
ac,
&mut coefs[sof_idx][bi],
sos.ss as usize,
sos.se as usize,
sos.al,
&mut eob_run,
)?;
}
}
mcus_since_restart += 1;
}
}
Ok(())
}
/// Decodes the DC coefficient of one block in a progressive DC scan.
///
/// First pass (`ah == 0`): decodes a Huffman-coded difference, adds it
/// (wrapping) to `*prev_dc` and stores the predictor shifted left by `al`.
/// Refinement (`ah != 0`): one raw bit is OR-ed into bit `al` of `block[0]`,
/// and the Huffman table is not consulted.
fn prog_decode_dc(
    br: &mut BitReader<'_>,
    dc: &HuffTable,
    prev_dc: &mut i32,
    block: &mut [i32; 64],
    ah: u8,
    al: u8,
) -> Result<()> {
    if ah != 0 {
        let bit = br.get_bits(1)? as i32;
        block[0] |= bit << al;
        return Ok(());
    }
    let size = u32::from(decode_huff(br, dc)?);
    let diff = if size == 0 {
        0
    } else {
        extend(br.get_bits(size)? as i32, size)
    };
    *prev_dc = prev_dc.wrapping_add(diff);
    block[0] = *prev_dc << al;
    Ok(())
}
/// Decodes one block of a progressive AC first-pass scan (`Ah == 0`) over
/// zig-zag positions `ss..=se`, scaling each value by `1 << al`.
///
/// `*eob_run` counts the remaining blocks of an end-of-band run. While it is
/// positive, the block is skipped and the counter is decremented. An EOBn
/// symbol starts a new run of `2^r + extra` blocks, including this one.
///
/// # Errors
/// `Error::invalid` if a run moves past `se`; Huffman and bit-reader errors
/// are propagated.
fn prog_decode_ac_first(
    br: &mut BitReader<'_>,
    ac: &HuffTable,
    block: &mut [i32; 64],
    ss: usize,
    se: usize,
    al: u8,
    eob_run: &mut u32,
) -> Result<()> {
    if *eob_run != 0 {
        *eob_run -= 1;
        return Ok(());
    }
    let mut k = ss;
    while k <= se {
        let symbol = decode_huff(br, ac)?;
        let run = (symbol >> 4) as usize;
        let size = u32::from(symbol & 0x0F);
        match size {
            // ZRL: sixteen zero coefficients.
            0 if run == 15 => k += 16,
            // EOBn: this block ends; the run covers the following blocks too.
            0 => {
                let extra = if run == 0 { 0 } else { br.get_bits(run as u32)? };
                *eob_run = (1u32 << run) + extra - 1;
                return Ok(());
            }
            _ => {
                k += run;
                if k > se {
                    return Err(Error::invalid("progressive AC: run out of band"));
                }
                let raw = br.get_bits(size)? as i32;
                block[ZIGZAG[k]] = extend(raw, size) << al;
                k += 1;
            }
        }
    }
    Ok(())
}
/// Decodes one block of a progressive AC *refinement* scan (`Ah != 0`) over
/// zig-zag positions `ss..=se` with point transform `al`.
///
/// Each Huffman symbol gives a run `r` of still-zero coefficients to skip.
/// With `s == 1` it also gives a newly non-zero coefficient of magnitude
/// `1 << al`, whose sign comes from the next bit (1 = positive). Every
/// already non-zero coefficient passed on the way reads one correction bit.
/// When that bit is set and bit `al` is still clear, `1 << al` is added in
/// the direction of the coefficient's sign.
///
/// `*eob_run` carries the end-of-band run across blocks. While it is positive
/// on entry, this block only receives correction bits for its remaining
/// non-zero coefficients, and the run is decremented.
///
/// # Errors
/// `Error::invalid` if a symbol has a magnitude other than 0 or 1, or if the
/// run walks past `se`; Huffman and bit-reader errors are propagated.
fn prog_decode_ac_refine(
    br: &mut BitReader<'_>,
    ac: &HuffTable,
    block: &mut [i32; 64],
    ss: usize,
    se: usize,
    al: u8,
    eob_run: &mut u32,
) -> Result<()> {
    // +1 and -1 at the bit position being refined.
    let p1: i32 = 1 << al;
    let m1: i32 = -1 << al;
    let mut k = ss;
    if *eob_run == 0 {
        while k <= se {
            let rs = decode_huff(br, ac)?;
            let mut r = (rs >> 4) as usize;
            let s = (rs & 0x0F) as u32;
            let new_val: i32;
            if s == 0 {
                if r != 15 {
                    // EOBn: starts a run of 2^r + extra blocks (including this
                    // one); the tail of this block is refined below.
                    let extra = if r == 0 { 0 } else { br.get_bits(r as u32)? };
                    *eob_run = (1u32 << r) + extra;
                    break;
                }
                // ZRL: skip 16 zero-history coefficients, no new value.
                new_val = 0;
            } else if s == 1 {
                let sign_bit = br.get_bits(1)?;
                new_val = if sign_bit == 0 { m1 } else { p1 };
            } else {
                return Err(Error::invalid("progressive AC refine: bad s"));
            }
            // Walk forward: refine non-zero coefficients, count down the run
            // over zero ones; stop on the zero that receives `new_val`.
            loop {
                if k > se {
                    return Err(Error::invalid("progressive AC refine: k past se"));
                }
                let pos = ZIGZAG[k];
                if block[pos] != 0 {
                    let bit = br.get_bits(1)? as i32;
                    if bit != 0 && (block[pos] & p1) == 0 {
                        if block[pos] >= 0 {
                            block[pos] += p1;
                        } else {
                            block[pos] += m1;
                        }
                    }
                } else if r == 0 {
                    break;
                } else {
                    r -= 1;
                }
                k += 1;
            }
            if new_val != 0 && k <= se {
                block[ZIGZAG[k]] = new_val;
            }
            k += 1;
        }
    }
    // Inside an EOB run: only correction bits for the remaining non-zero
    // coefficients of the band.
    if *eob_run > 0 {
        while k <= se {
            let pos = ZIGZAG[k];
            if block[pos] != 0 {
                let bit = br.get_bits(1)? as i32;
                if bit != 0 && (block[pos] & p1) == 0 {
                    if block[pos] >= 0 {
                        block[pos] += p1;
                    } else {
                        block[pos] += m1;
                    }
                }
            }
            k += 1;
        }
        *eob_run -= 1;
    }
    Ok(())
}
fn render_from_coefs(
state: &JpegState,
coefs: &[Vec<[i32; 64]>],
pts: Option<i64>,
time_base: TimeBase,
) -> Result<VideoFrame> {
let sof = state
.sof
.as_ref()
.ok_or_else(|| Error::invalid("render: EOI before SOF"))?;
if sof.precision == 12 {
return render_from_coefs_12bit(state, coefs, pts, time_base);
}
let _ = time_base;
let n_comp = sof.components.len();
let grayscale = n_comp == 1;
let width = sof.width as usize;
let height = sof.height as usize;
let h_max = sof.components.iter().map(|c| c.h_factor).max().unwrap_or(1) as usize;
let v_max = sof.components.iter().map(|c| c.v_factor).max().unwrap_or(1) as usize;
let mcus_x = width.div_ceil(8 * h_max);
let mcus_y = height.div_ceil(8 * v_max);
let out_format = if grayscale {
PixelFormat::Gray8
} else if n_comp == 3 {
let y = sof.components[0];
let cb = sof.components[1];
let cr = sof.components[2];
if cb.h_factor != cr.h_factor || cb.v_factor != cr.v_factor {
return Err(Error::unsupported(
"chroma components have different sampling factors",
));
}
if cb.h_factor != 1 || cb.v_factor != 1 {
return Err(Error::unsupported("chroma components must have factor 1"));
}
match (y.h_factor, y.v_factor) {
(1, 1) => PixelFormat::Yuv444P,
(2, 1) => PixelFormat::Yuv422P,
(2, 2) => PixelFormat::Yuv420P,
_ => {
return Err(Error::unsupported(format!(
"luma sampling {}x{}",
y.h_factor, y.v_factor
)))
}
}
} else if n_comp == 4 {
PixelFormat::Cmyk
} else {
return Err(Error::unsupported("2-component JPEG"));
};
let quant_tables: Vec<&QuantTable> = sof
.components
.iter()
.map(|c| {
state.quant[c.qt_id as usize]
.as_ref()
.ok_or_else(|| Error::invalid("quant table missing for component"))
})
.collect::<Result<Vec<_>>>()?;
let mut comp_buf: Vec<Vec<u8>> = Vec::with_capacity(n_comp);
let mut comp_stride: Vec<usize> = Vec::with_capacity(n_comp);
for c in &sof.components {
let w_full = mcus_x * 8 * c.h_factor as usize;
let h_full = mcus_y * 8 * c.v_factor as usize;
comp_buf.push(vec![0u8; w_full * h_full]);
comp_stride.push(w_full);
}
for (ci, c) in sof.components.iter().enumerate() {
let blocks_x = mcus_x * c.h_factor as usize;
let blocks_y = mcus_y * c.v_factor as usize;
let qt = quant_tables[ci];
let stride = comp_stride[ci];
let buf = &mut comp_buf[ci];
for by in 0..blocks_y {
for bx in 0..blocks_x {
let block = &coefs[ci][by * blocks_x + bx];
let mut fblock = [0.0f32; 64];
for k in 0..64 {
fblock[k] = (block[k] * qt.values[k] as i32) as f32;
}
idct8x8(&mut fblock);
let dst_x0 = bx * 8;
let dst_y0 = by * 8;
for j in 0..8 {
for i in 0..8 {
let v = fblock[j * 8 + i] + 128.0;
let px = if v <= 0.0 {
0
} else if v >= 255.0 {
255
} else {
v.round() as u8
};
buf[(dst_y0 + j) * stride + dst_x0 + i] = px;
}
}
}
}
}
let mut planes: Vec<VideoPlane> = Vec::new();
match out_format {
PixelFormat::Gray8 => {
let stride = width;
let mut data = vec![0u8; stride * height];
let src_stride = comp_stride[0];
for y in 0..height {
data[y * stride..y * stride + width]
.copy_from_slice(&comp_buf[0][y * src_stride..y * src_stride + width]);
}
planes.push(VideoPlane { stride, data });
}
PixelFormat::Yuv444P | PixelFormat::Yuv422P | PixelFormat::Yuv420P => {
let (c_w, c_h) = match out_format {
PixelFormat::Yuv444P => (width, height),
PixelFormat::Yuv422P => (width.div_ceil(2), height),
PixelFormat::Yuv420P => (width.div_ceil(2), height.div_ceil(2)),
_ => unreachable!(),
};
let y_stride = width;
let mut y_data = vec![0u8; y_stride * height];
let src_stride_y = comp_stride[0];
for y in 0..height {
y_data[y * y_stride..y * y_stride + width]
.copy_from_slice(&comp_buf[0][y * src_stride_y..y * src_stride_y + width]);
}
planes.push(VideoPlane {
stride: y_stride,
data: y_data,
});
for ci in [1usize, 2] {
let src_stride = comp_stride[ci];
let stride = c_w;
let mut data = vec![0u8; stride * c_h];
for y in 0..c_h {
data[y * stride..y * stride + c_w]
.copy_from_slice(&comp_buf[ci][y * src_stride..y * src_stride + c_w]);
}
planes.push(VideoPlane { stride, data });
}
}
PixelFormat::Cmyk => {
let stride = width * 4;
let mut data = vec![0u8; stride * height];
let fh: [usize; 4] = [
sof.components[0].h_factor as usize,
sof.components[1].h_factor as usize,
sof.components[2].h_factor as usize,
sof.components[3].h_factor as usize,
];
let fv: [usize; 4] = [
sof.components[0].v_factor as usize,
sof.components[1].v_factor as usize,
sof.components[2].v_factor as usize,
sof.components[3].v_factor as usize,
];
let transform = state.adobe_transform;
for y in 0..height {
for x in 0..width {
let mut s = [0u8; 4];
for ci in 0..4 {
let sx = x * fh[ci] / h_max;
let sy = y * fv[ci] / v_max;
s[ci] = comp_buf[ci][sy * comp_stride[ci] + sx];
}
let (c, m, yy, k) = match transform {
Some(2) => {
let y_s = s[0] as i32;
let cb = s[1] as i32 - 128;
let cr = s[2] as i32 - 128;
let r = (y_s + ((cr * 91881 + 32768) >> 16)).clamp(0, 255);
let g = (y_s - ((cb * 22554 + cr * 46802 + 32768) >> 16)).clamp(0, 255);
let b = (y_s + ((cb * 116130 + 32768) >> 16)).clamp(0, 255);
(
(255 - r) as u8,
(255 - g) as u8,
(255 - b) as u8,
255 - s[3],
)
}
Some(0) => {
(255 - s[0], 255 - s[1], 255 - s[2], 255 - s[3])
}
_ => {
(s[0], s[1], s[2], s[3])
}
};
let o = y * stride + x * 4;
data[o] = c;
data[o + 1] = m;
data[o + 2] = yy;
data[o + 3] = k;
}
}
planes.push(VideoPlane { stride, data });
}
_ => unreachable!(),
}
Ok(VideoFrame { pts, planes })
}
/// Renders a 12-bit frame from the accumulated coefficient buffers at EOI.
///
/// Every block is dequantized, inverse-DCT'd, shifted by +2048 and clamped to
/// `0..=4095` in an MCU-padded `u16` plane. Planes are then cropped and
/// emitted as little-endian 16-bit samples: 1 component gives `Gray12Le`,
/// and 3 components with 2x2 luma and 1x1 chroma give `Yuv420P12Le`.
///
/// # Errors
/// `Error::invalid` if no SOF was seen or a quantization table is missing;
/// `Error::unsupported` for any other component count or sampling layout.
fn render_from_coefs_12bit(
    state: &JpegState,
    coefs: &[Vec<[i32; 64]>],
    pts: Option<i64>,
    _time_base: TimeBase,
) -> Result<VideoFrame> {
    let sof = state
        .sof
        .as_ref()
        .ok_or_else(|| Error::invalid("render: EOI before SOF"))?;
    let n_comp = sof.components.len();
    let grayscale = n_comp == 1;
    let width = sof.width as usize;
    let height = sof.height as usize;
    let h_max = sof.components.iter().map(|c| c.h_factor).max().unwrap_or(1) as usize;
    let v_max = sof.components.iter().map(|c| c.v_factor).max().unwrap_or(1) as usize;
    let mcus_x = width.div_ceil(8 * h_max);
    let mcus_y = height.div_ceil(8 * v_max);
    let out_format = if grayscale {
        PixelFormat::Gray12Le
    } else if n_comp == 3 {
        let y = sof.components[0];
        let cb = sof.components[1];
        let cr = sof.components[2];
        if cb.h_factor != cr.h_factor || cb.v_factor != cr.v_factor {
            return Err(Error::unsupported(
                "12-bit: chroma components have different sampling factors",
            ));
        }
        if cb.h_factor != 1 || cb.v_factor != 1 {
            return Err(Error::unsupported(
                "12-bit: chroma components must have factor 1",
            ));
        }
        match (y.h_factor, y.v_factor) {
            (2, 2) => PixelFormat::Yuv420P12Le,
            _ => {
                return Err(Error::unsupported(format!(
                    "12-bit: only 4:2:0 chroma sampling supported (got {}x{})",
                    y.h_factor, y.v_factor
                )))
            }
        }
    } else {
        return Err(Error::unsupported(format!(
            "12-bit: {n_comp}-component JPEGs not supported"
        )));
    };
    // `.get` turns an out-of-range table id into an error instead of a panic.
    let quant_tables: Vec<&QuantTable> = sof
        .components
        .iter()
        .map(|c| {
            state
                .quant
                .get(c.qt_id as usize)
                .and_then(Option::as_ref)
                .ok_or_else(|| Error::invalid("quant table missing for component"))
        })
        .collect::<Result<Vec<_>>>()?;

    // MCU-padded sample planes, one per component.
    let mut comp_buf: Vec<Vec<u16>> = Vec::with_capacity(n_comp);
    let mut comp_stride: Vec<usize> = Vec::with_capacity(n_comp);
    for c in &sof.components {
        let w_full = mcus_x * 8 * c.h_factor as usize;
        let h_full = mcus_y * 8 * c.v_factor as usize;
        comp_buf.push(vec![0u16; w_full * h_full]);
        comp_stride.push(w_full);
    }
    for (ci, c) in sof.components.iter().enumerate() {
        let blocks_x = mcus_x * c.h_factor as usize;
        let blocks_y = mcus_y * c.v_factor as usize;
        let qt = quant_tables[ci];
        let stride = comp_stride[ci];
        let buf = &mut comp_buf[ci];
        for by in 0..blocks_y {
            for bx in 0..blocks_x {
                let block = &coefs[ci][by * blocks_x + bx];
                // Dequantize in f32. This gives the same correctly rounded value
                // as the exact integer product, but cannot overflow i32 (and
                // panic in debug builds) on corrupt accumulated DC values.
                let mut fblock = [0.0f32; 64];
                for ((dst, &coef), &q) in fblock.iter_mut().zip(block.iter()).zip(qt.values.iter())
                {
                    *dst = coef as f32 * q as f32;
                }
                idct8x8(&mut fblock);
                for (j, row) in fblock.chunks_exact(8).enumerate() {
                    let start = (by * 8 + j) * stride + bx * 8;
                    for (px, &s) in buf[start..start + 8].iter_mut().zip(row) {
                        let v = s + 2048.0;
                        *px = if v <= 0.0 {
                            0
                        } else if v >= 4095.0 {
                            4095
                        } else {
                            v.round() as u16
                        };
                    }
                }
            }
        }
    }

    // Crops a padded u16 plane to `w` x `h` and serializes it little-endian.
    let emit_plane = |src: &[u16], src_stride: usize, w: usize, h: usize| -> VideoPlane {
        let stride = w * 2;
        let mut data = vec![0u8; stride * h];
        for y in 0..h {
            for x in 0..w {
                let [lo, hi] = src[y * src_stride + x].to_le_bytes();
                data[y * stride + x * 2] = lo;
                data[y * stride + x * 2 + 1] = hi;
            }
        }
        VideoPlane { stride, data }
    };
    let mut planes: Vec<VideoPlane> = Vec::new();
    match out_format {
        PixelFormat::Gray12Le => {
            planes.push(emit_plane(&comp_buf[0], comp_stride[0], width, height));
        }
        PixelFormat::Yuv420P12Le => {
            planes.push(emit_plane(&comp_buf[0], comp_stride[0], width, height));
            let c_w = width.div_ceil(2);
            let c_h = height.div_ceil(2);
            planes.push(emit_plane(&comp_buf[1], comp_stride[1], c_w, c_h));
            planes.push(emit_plane(&comp_buf[2], comp_stride[2], c_w, c_h));
        }
        _ => unreachable!(),
    }
    Ok(VideoFrame { pts, planes })
}
fn decode_lossless_scan(
state: &JpegState,
sos: &SosInfo,
scan: &[u8],
pts: Option<i64>,
_time_base: TimeBase,
) -> Result<VideoFrame> {
let sof = state
.sof
.as_ref()
.ok_or_else(|| Error::invalid("SOS before SOF"))?;
if sos.components.len() != 1 {
return Err(Error::unsupported(
"lossless: multi-component scans are not supported",
));
}
let predictor = sos.ss;
if !(1..=7).contains(&predictor) {
return Err(Error::invalid("lossless: predictor Ss must be in 1..=7"));
}
let pt = sos.al as u32; let precision = sof.precision as u32;
if pt >= precision {
return Err(Error::invalid("lossless: Pt >= precision"));
}
let width = sof.width as usize;
let height = sof.height as usize;
let dc_t = state.dc_huff[sos.components[0].dc_table as usize]
.as_ref()
.ok_or_else(|| Error::invalid("lossless: DC Huffman table missing"))?;
let sample_bits = precision - pt;
let sample_max: u32 = 1u32 << sample_bits;
let sample_mask: u32 = sample_max - 1;
let origin: u32 = 1u32 << (sample_bits - 1);
let mut samples = vec![0u32; width * height];
let mut br = BitReader::new(scan);
let mut mcus_since_restart: u32 = 0;
let mut expected_rst: u8 = RST0;
let mut reset_pred = true;
for y in 0..height {
for x in 0..width {
if state.restart_interval != 0
&& mcus_since_restart != 0
&& mcus_since_restart % state.restart_interval as u32 == 0
{
br.bits = 0;
br.nbits = 0;
while br.saw_rst.is_none() {
let prev = br.pos;
match br.next_byte_with_stuff()? {
Some(_) => {}
None => {
if prev == br.pos {
break;
}
break;
}
}
}
if br.saw_rst.is_some() {
expected_rst = if expected_rst == RST7 {
RST0
} else {
expected_rst + 1
};
br.reset_at_restart();
reset_pred = true;
}
}
let pred: u32 = if reset_pred {
reset_pred = false;
origin
} else if y == 0 {
samples[y * width + x - 1]
} else if x == 0 {
samples[(y - 1) * width + x]
} else {
let ra = samples[y * width + x - 1];
let rb = samples[(y - 1) * width + x];
let rc = samples[(y - 1) * width + x - 1];
match predictor {
1 => ra,
2 => rb,
3 => rc,
4 => ra.wrapping_add(rb).wrapping_sub(rc),
5 => ra.wrapping_add(rb.wrapping_sub(rc) >> 1),
6 => rb.wrapping_add(ra.wrapping_sub(rc) >> 1),
7 => (ra.wrapping_add(rb)) >> 1,
_ => unreachable!(),
}
};
let s = decode_huff(&mut br, dc_t)? as u32;
let residual: i32 = if s == 0 {
0
} else if s == 16 {
32_768
} else {
let bits = br.get_bits(s)? as i32;
extend(bits, s)
};
let sv = ((pred as i32).wrapping_add(residual) as u32) & sample_mask;
samples[y * width + x] = sv;
mcus_since_restart += 1;
}
}
let out_format = match precision {
8 => PixelFormat::Gray8,
10 => PixelFormat::Gray10Le,
12 => PixelFormat::Gray12Le,
_ => PixelFormat::Gray16Le,
};
let plane = if out_format == PixelFormat::Gray8 {
let stride = width;
let mut data = vec![0u8; stride * height];
for i in 0..width * height {
data[i] = (samples[i] << pt) as u8;
}
VideoPlane { stride, data }
} else {
let stride = width * 2;
let mut data = vec![0u8; stride * height];
for i in 0..width * height {
let v = (samples[i] << pt) as u16;
data[i * 2] = (v & 0xFF) as u8;
data[i * 2 + 1] = (v >> 8) as u8;
}
VideoPlane { stride, data }
};
Ok(VideoFrame {
pts,
planes: vec![plane],
})
}
#[cfg(test)]
mod prog_tests {
    //! Unit tests for the progressive-JPEG helpers: coefficient buffer allocation
    //! and the DC/AC first-scan and refinement-scan decoders.
    use super::*;
    use crate::jpeg::huffman::{
        HuffTable, STD_AC_LUMA_BITS, STD_AC_LUMA_VALS, STD_DC_LUMA_BITS, STD_DC_LUMA_VALS,
    };
    use crate::jpeg::parser::SofComponent;

    /// The standard luminance DC Huffman table.
    fn std_dc_luma() -> HuffTable {
        HuffTable::build(&STD_DC_LUMA_BITS, &STD_DC_LUMA_VALS).unwrap()
    }

    /// The standard luminance AC Huffman table.
    fn std_ac_luma() -> HuffTable {
        HuffTable::build(&STD_AC_LUMA_BITS, &STD_AC_LUMA_VALS).unwrap()
    }

    /// Builds a one-component 8-bit `JpegState` with an all-ones quantisation
    /// table and the standard luma DC table, both in slot 0.
    fn make_state(progressive: bool, width: u16, height: u16) -> JpegState {
        let luma = SofComponent {
            id: 1,
            h_factor: 1,
            v_factor: 1,
            qt_id: 0,
        };
        let mut state = JpegState::new();
        state.progressive = progressive;
        state.sof = Some(SofInfo {
            precision: 8,
            height,
            width,
            components: vec![luma],
        });
        state.quant[0] = Some(QuantTable { values: [1u16; 64] });
        state.dc_huff[0] = Some(std_dc_luma());
        state
    }

    /// Bitstream: AC symbol 0x01 (run 0, size 1), magnitude bit `1` (value +1), then EOB.
    fn single_plus_one_then_eob(ac: &HuffTable) -> Vec<u8> {
        let (one, eob) = (ac.encode[0x01], ac.encode[0x00]);
        let mut writer = ProgTestBitWriter::new();
        writer.put(one.code as u32, one.len as u32);
        writer.put(1, 1);
        writer.put(eob.code as u32, eob.len as u32);
        writer.finish()
    }

    #[test]
    fn accumulator_sizes() {
        let comp = |id, h_factor, v_factor, qt_id| SofComponent {
            id,
            h_factor,
            v_factor,
            qt_id,
        };
        // 16x16 with Y sampled 2x2 and chroma 1x1: expect 4 Y blocks and 1 per chroma.
        let sof = SofInfo {
            precision: 8,
            height: 16,
            width: 16,
            components: vec![comp(1, 2, 2, 0), comp(2, 1, 1, 1), comp(3, 1, 1, 1)],
        };
        let coefs = init_coef_buffers(&sof).unwrap();
        for (ci, &expected) in [4usize, 1, 1].iter().enumerate() {
            assert_eq!(coefs[ci].len(), expected);
        }
        assert!(coefs[0].iter().all(|blk| blk.iter().all(|&v| v == 0)));
    }

    #[test]
    fn dc_first_pass_shifts_by_al() {
        let dc = std_dc_luma();
        // Category 3 plus the 3 magnitude bits 0b101 encodes a DC difference of +5.
        let cat3 = dc.encode[3];
        let mut writer = ProgTestBitWriter::new();
        writer.put(cat3.code as u32, cat3.len as u32);
        writer.put(0b101, 3);
        let bytes = writer.finish();

        let mut reader = BitReader::new(&bytes);
        let mut pred = 0i32;
        let mut coeffs = [0i32; 64];
        prog_decode_dc(&mut reader, &dc, &mut pred, &mut coeffs, 0, 2).unwrap();
        // The predictor keeps the unshifted value; the coefficient is scaled by 2^Al.
        assert_eq!(pred, 5);
        assert_eq!(coeffs[0], 20);
    }

    #[test]
    fn dc_refine_appends_bit() {
        let dc = std_dc_luma();
        let mut writer = ProgTestBitWriter::new();
        writer.put(1, 1);
        let bytes = writer.finish();

        let mut reader = BitReader::new(&bytes);
        let mut pred = 0i32;
        let mut coeffs = [0i32; 64];
        coeffs[0] = 0b1000;
        prog_decode_dc(&mut reader, &dc, &mut pred, &mut coeffs, 1, 0).unwrap();
        assert_eq!(coeffs[0], 0b1001);
    }

    #[test]
    fn state_helper_is_coherent() {
        let state = make_state(true, 8, 8);
        assert!(state.progressive);
        assert!(state.sof.is_some());
    }

    /// Minimal MSB-first bit writer for building test scans. Inserts a 0x00 after
    /// every emitted 0xFF byte and pads the final partial byte with 1-bits.
    struct ProgTestBitWriter {
        bytes: Vec<u8>,
        acc: u32,
        pending: u32,
    }

    impl ProgTestBitWriter {
        fn new() -> Self {
            Self {
                bytes: Vec::new(),
                acc: 0,
                pending: 0,
            }
        }

        /// Appends one byte, stuffing a zero after 0xFF.
        fn emit(&mut self, byte: u8) {
            self.bytes.push(byte);
            if byte == 0xFF {
                self.bytes.push(0x00);
            }
        }

        /// Appends the low `n` bits of `val`, most significant first.
        fn put(&mut self, val: u32, n: u32) {
            let low_bits = val & ((1u32 << n) - 1);
            self.acc = (self.acc << n) | low_bits;
            self.pending += n;
            while self.pending >= 8 {
                self.pending -= 8;
                let byte = (self.acc >> self.pending) as u8;
                self.emit(byte);
            }
        }

        /// Flushes any partial byte (1-padded) and returns the stuffed bytes.
        fn finish(mut self) -> Vec<u8> {
            if self.pending > 0 {
                let pad = 8 - self.pending;
                let byte = ((self.acc << pad) | ((1u32 << pad) - 1)) as u8;
                self.emit(byte);
            }
            self.bytes
        }
    }

    #[test]
    fn ac_first_single_coef_then_eob() {
        let ac = std_ac_luma();
        let bytes = single_plus_one_then_eob(&ac);
        let mut reader = BitReader::new(&bytes);
        let mut coeffs = [0i32; 64];
        let mut eob_run = 0u32;
        prog_decode_ac_first(&mut reader, &ac, &mut coeffs, 1, 63, 0, &mut eob_run).unwrap();

        let target = ZIGZAG[1];
        assert_eq!(coeffs[target], 1);
        for (i, &v) in coeffs.iter().enumerate() {
            if i != target {
                assert_eq!(v, 0, "unexpected nonzero at {i}");
            }
        }
    }

    #[test]
    fn ac_first_pass_shifts_by_al() {
        let ac = std_ac_luma();
        let bytes = single_plus_one_then_eob(&ac);
        let mut reader = BitReader::new(&bytes);
        let mut coeffs = [0i32; 64];
        let mut eob_run = 0u32;
        prog_decode_ac_first(&mut reader, &ac, &mut coeffs, 1, 63, 2, &mut eob_run).unwrap();
        // +1 scaled by 2^Al with Al = 2.
        assert_eq!(coeffs[ZIGZAG[1]], 4);
    }

    #[test]
    fn ac_refine_extends_existing_coef() {
        let ac = std_ac_luma();
        // EOB, then one correction bit for the already-nonzero coefficient.
        let eob = ac.encode[0x00];
        let mut writer = ProgTestBitWriter::new();
        writer.put(eob.code as u32, eob.len as u32);
        writer.put(1, 1);
        let bytes = writer.finish();

        let mut reader = BitReader::new(&bytes);
        let mut coeffs = [0i32; 64];
        coeffs[ZIGZAG[1]] = 4;
        let mut eob_run = 0u32;
        prog_decode_ac_refine(&mut reader, &ac, &mut coeffs, 1, 63, 1, &mut eob_run).unwrap();
        assert_eq!(coeffs[ZIGZAG[1]], 6);
    }
}
#[cfg(test)]
mod non_interleaved_tests {
    //! A non-interleaved baseline JPEG (one SOS per component) must decode to
    //! byte-identical planes as the interleaved encoding of the same frame.
    use super::*;
    use crate::encoder::{encode_jpeg, encode_jpeg_non_interleaved};
    use oxideav_core::frame::VideoPlane;

    /// Chroma test value: a gentle ramp up from 128, saturated at 255.
    ///
    /// The clamp happens before narrowing: `as u8` wraps, so clamping the
    /// already-cast `u8` could never change anything.
    fn chroma_ramp(t: usize) -> u8 {
        (128 + t / 2).min(255) as u8
    }

    /// Builds a planar YUV test frame: a diagonal luma pattern, Cb ramping
    /// horizontally and Cr ramping vertically.
    ///
    /// # Panics
    /// If `pix` is not `Yuv444P`, `Yuv422P` or `Yuv420P`.
    fn make_frame(w: u32, h: u32, pix: PixelFormat) -> VideoFrame {
        let (cw, ch) = match pix {
            PixelFormat::Yuv444P => (w, h),
            PixelFormat::Yuv422P => (w.div_ceil(2), h),
            PixelFormat::Yuv420P => (w.div_ceil(2), h.div_ceil(2)),
            _ => panic!("unsupported"),
        };
        let (w, h) = (w as usize, h as usize);
        let (cw, ch) = (cw as usize, ch as usize);

        let mut y = vec![0u8; w * h];
        for j in 0..h {
            for i in 0..w {
                y[j * w + i] = (((i + j * 3) * 7) % 255) as u8;
            }
        }
        let mut cb = vec![0u8; cw * ch];
        let mut cr = vec![0u8; cw * ch];
        for j in 0..ch {
            for i in 0..cw {
                cb[j * cw + i] = chroma_ramp(i);
                cr[j * cw + i] = chroma_ramp(j);
            }
        }
        VideoFrame {
            pts: Some(0),
            planes: vec![
                VideoPlane {
                    stride: w,
                    data: y,
                },
                VideoPlane {
                    stride: cw,
                    data: cb,
                },
                VideoPlane {
                    stride: cw,
                    data: cr,
                },
            ],
        }
    }

    /// Decodes one JPEG packet with a fresh MJPEG decoder.
    fn decode_one(w: u32, h: u32, jpeg: Vec<u8>) -> VideoFrame {
        let mut params = CodecParameters::video(CodecId::new("mjpeg"));
        params.width = Some(w);
        params.height = Some(h);
        let mut dec = make_decoder(&params).unwrap();
        dec.send_packet(&Packet::new(0, TimeBase::new(1, 30), jpeg))
            .unwrap();
        let Frame::Video(frame) = dec.receive_frame().unwrap() else {
            panic!()
        };
        frame
    }

    /// Encodes the same frame interleaved and non-interleaved (checking the latter
    /// really has three SOS segments) and asserts both decode identically.
    fn assert_matches_interleaved(w: u32, h: u32, pix: PixelFormat) {
        let frame = make_frame(w, h, pix);
        let base = encode_jpeg(&frame, w, h, pix, 75).expect("interleaved encode");
        let non =
            encode_jpeg_non_interleaved(&frame, w, h, pix, 75).expect("non-interleaved encode");
        let sos_count = non.windows(2).filter(|pair| pair == &[0xFF, 0xDA]).count();
        assert_eq!(sos_count, 3, "expected 3 SOS segments");

        let va = decode_one(w, h, base);
        let vb = decode_one(w, h, non);
        assert_eq!(va.planes.len(), vb.planes.len());
        for (pi, (pa, pb)) in va.planes.iter().zip(vb.planes.iter()).enumerate() {
            assert_eq!(
                pa.data, pb.data,
                "plane {pi} mismatch between interleaved and non-interleaved decodes"
            );
        }
    }

    #[test]
    fn non_interleaved_yuv420p_matches_interleaved() {
        assert_matches_interleaved(32, 16, PixelFormat::Yuv420P);
    }

    #[test]
    fn non_interleaved_yuv422p_matches_interleaved() {
        assert_matches_interleaved(24, 24, PixelFormat::Yuv422P);
    }

    #[test]
    fn non_interleaved_yuv444p_matches_interleaved() {
        assert_matches_interleaved(16, 16, PixelFormat::Yuv444P);
    }
}
#[cfg(test)]
mod cmyk_tests {
    //! Round trips for 4-component JPEGs: plain CMYK, Adobe-marked CMYK and YCCK.
    //! The decoded frame is read as one packed plane with four bytes per pixel,
    //! component `ci` at byte offset `ci` of each pixel.
    use super::*;
    use crate::encoder::encode_jpeg_cmyk_1111;

    /// PSNR in dB between two equally-sized 8-bit buffers; identical input reports 99 dB.
    fn psnr(a: &[u8], b: &[u8]) -> f64 {
        assert_eq!(a.len(), b.len());
        let sse: f64 = a
            .iter()
            .zip(b)
            .map(|(&x, &y)| {
                let d = f64::from(x) - f64::from(y);
                d * d
            })
            .sum();
        if sse == 0.0 {
            return 99.0;
        }
        20.0 * (255.0 / (sse / a.len() as f64).sqrt()).log10()
    }

    /// Synthetic C/M/Y/K planes: horizontal, vertical and diagonal ramps, plus a
    /// halved XOR texture for K.
    fn make_cmyk_planes(w: usize, h: usize) -> [Vec<u8>; 4] {
        let mut c = vec![0u8; w * h];
        let mut m = vec![0u8; w * h];
        let mut y = vec![0u8; w * h];
        let mut k = vec![0u8; w * h];
        for j in 0..h {
            for i in 0..w {
                c[j * w + i] = (i * 255 / w.max(1)).min(255) as u8;
                m[j * w + i] = (j * 255 / h.max(1)).min(255) as u8;
                y[j * w + i] = ((i + j) * 255 / (w + h).max(1)).min(255) as u8;
                k[j * w + i] = ((((i ^ j) * 7) & 0xFF) as u8) / 2;
            }
        }
        [c, m, y, k]
    }

    /// Decodes one JPEG packet with a fresh MJPEG decoder.
    fn decode_one(w: u32, h: u32, jpeg: Vec<u8>) -> VideoFrame {
        let mut params = CodecParameters::video(CodecId::new("mjpeg"));
        params.width = Some(w);
        params.height = Some(h);
        let mut dec = make_decoder(&params).unwrap();
        dec.send_packet(&Packet::new(0, TimeBase::new(1, 30), jpeg))
            .unwrap();
        let Frame::Video(frame) = dec.receive_frame().unwrap() else {
            panic!()
        };
        frame
    }

    /// Extracts component `ci` from the packed 4-bytes-per-pixel first plane.
    fn packed_component(frame: &VideoFrame, w: usize, h: usize, ci: usize) -> Vec<u8> {
        let plane = &frame.planes[0];
        let mut out = Vec::with_capacity(w * h);
        for j in 0..h {
            for i in 0..w {
                out.push(plane.data[j * plane.stride + i * 4 + ci]);
            }
        }
        out
    }

    #[test]
    fn cmyk_plain_roundtrip() {
        let (w, h) = (32u32, 16u32);
        let planes = make_cmyk_planes(w as usize, h as usize);
        let refs: [&[u8]; 4] = [&planes[0], &planes[1], &planes[2], &planes[3]];
        let strides = [w as usize; 4];
        let data =
            encode_jpeg_cmyk_1111(w, h, &refs, &strides, 90, None).expect("encode plain CMYK");
        let v = decode_one(w, h, data);
        assert_eq!(v.planes.len(), 1);
        assert_eq!(v.planes[0].stride, (w * 4) as usize);
        for (ci, src) in planes.iter().enumerate() {
            let got = packed_component(&v, w as usize, h as usize, ci);
            let p = psnr(src, &got);
            assert!(p >= 30.0, "component {ci} PSNR too low: {p:.2}");
        }
    }

    #[test]
    fn cmyk_adobe_inverted_roundtrip() {
        let (w, h) = (16u32, 16u32);
        let planes = make_cmyk_planes(w as usize, h as usize);
        let refs: [&[u8]; 4] = [&planes[0], &planes[1], &planes[2], &planes[3]];
        let strides = [w as usize; 4];
        let data =
            encode_jpeg_cmyk_1111(w, h, &refs, &strides, 90, Some(0)).expect("encode Adobe CMYK");
        let v = decode_one(w, h, data);
        for (ci, src) in planes.iter().enumerate() {
            let got = packed_component(&v, w as usize, h as usize, ci);
            let p = psnr(src, &got);
            assert!(p >= 30.0, "Adobe CMYK component {ci} PSNR too low: {p:.2}");
        }
    }

    #[test]
    fn ycck_roundtrip_k_plane_matches() {
        let (w, h) = (16u32, 16u32);
        let (wu, hu) = (w as usize, h as usize);
        // Neutral Y/Cb/Cr everywhere; K ramps horizontally from 0 to 255 on every row.
        let yp = vec![128u8; wu * hu];
        let cb = vec![128u8; wu * hu];
        let cr = vec![128u8; wu * hu];
        let k_row: Vec<u8> = (0..wu)
            .map(|i| (i * 255 / (wu - 1).max(1)).min(255) as u8)
            .collect();
        let k = k_row.repeat(hu);

        let refs: [&[u8]; 4] = [&yp, &cb, &cr, &k];
        let strides = [wu; 4];
        let data = encode_jpeg_cmyk_1111(w, h, &refs, &strides, 90, Some(2)).expect("encode YCCK");
        let v = decode_one(w, h, data);
        let got_k = packed_component(&v, wu, hu, 3);
        let p = psnr(&k, &got_k);
        assert!(p >= 30.0, "YCCK K plane PSNR too low: {p:.2}");
    }
}
#[cfg(test)]
mod precision_12_tests {
    //! 12-bit sequential round trip through the grayscale encoder and the decoder.
    use super::*;
    use crate::encoder::encode_grayscale_jpeg_12bit;

    #[test]
    fn gray_12bit_roundtrip() {
        let (w, h) = (16u32, 16u32);
        let (wu, hu) = (w as usize, h as usize);
        // Gentle diagonal ramp just below mid-scale: 2000 + x + y.
        let samples: Vec<u16> = (0..hu)
            .flat_map(|j| (0..wu).map(move |i| 2000 + (i + j) as u16))
            .collect();
        let data =
            encode_grayscale_jpeg_12bit(w, h, &samples, wu, 90).expect("encode 12-bit gray");

        let mut params = CodecParameters::video(CodecId::new("mjpeg"));
        params.width = Some(w);
        params.height = Some(h);
        let mut dec = make_decoder(&params).unwrap();
        dec.send_packet(&Packet::new(0, TimeBase::new(1, 30), data))
            .unwrap();
        let Frame::Video(frame) = dec.receive_frame().unwrap() else {
            panic!()
        };

        // One plane of little-endian 16-bit samples.
        assert_eq!(frame.planes.len(), 1);
        let plane = &frame.planes[0];
        assert_eq!(plane.stride, (w * 2) as usize);
        for j in 0..hu {
            for i in 0..wu {
                let off = j * plane.stride + i * 2;
                let got = u16::from_le_bytes([plane.data[off], plane.data[off + 1]]);
                let diff = (i32::from(samples[j * wu + i]) - i32::from(got)).abs();
                assert!(diff < 16, "12-bit roundtrip diff too large: {diff}");
            }
        }
    }
}
#[cfg(test)]
mod lossless_tests {
    //! Lossless (SOF3) decoding, plus a check that a hierarchical frame is refused.
    use super::*;
    use crate::encoder::encode_lossless_grayscale_jpeg_8bit;

    /// Pushes one packet through a fresh MJPEG decoder and returns its result.
    fn decode_packet(w: u32, h: u32, jpeg: Vec<u8>) -> Result<Frame> {
        let mut params = CodecParameters::video(CodecId::new("mjpeg"));
        params.width = Some(w);
        params.height = Some(h);
        let mut dec = make_decoder(&params).unwrap();
        dec.send_packet(&Packet::new(0, TimeBase::new(1, 30), jpeg))
            .unwrap();
        dec.receive_frame()
    }

    #[test]
    fn lossless_8bit_gray_exact_roundtrip() {
        let (w, h) = (24u32, 16u32);
        let stride = w as usize;
        // Deterministic pattern: two linear ramps plus a small XOR texture.
        let samples: Vec<u8> = (0..h as usize)
            .flat_map(|j| {
                (0..stride).map(move |i| {
                    ((i as i32 * 3 + j as i32 * 5 + ((i ^ j) as i32 & 7)) & 0xFF) as u8
                })
            })
            .collect();
        let jpeg = encode_lossless_grayscale_jpeg_8bit(w, h, &samples, stride)
            .expect("encode lossless");
        // 0xFF 0xC3 is the SOF3 (lossless, Huffman) marker.
        assert!(
            jpeg.windows(2).any(|m| m == [0xFF, 0xC3]),
            "SOF3 marker missing from lossless output"
        );

        let Frame::Video(frame) = decode_packet(w, h, jpeg).unwrap() else {
            panic!()
        };
        assert_eq!(frame.planes.len(), 1);
        let plane = &frame.planes[0];
        assert_eq!(plane.stride, stride);
        for (j, want_row) in samples.chunks(stride).enumerate() {
            for (i, &want) in want_row.iter().enumerate() {
                let got = plane.data[j * plane.stride + i];
                assert_eq!(got, want, "mismatch at ({i},{j})");
            }
        }
    }

    #[test]
    fn hierarchical_arithmetic_still_rejected() {
        // 0xFFC5 is SOF5: differential sequential DCT, Huffman-coded, i.e. a
        // hierarchical-mode frame. Its segment has Lf = 8 (P = 8, Y = 8, X = 8,
        // Nf = 1) and no component specification, so the test expects the decoder
        // to reject the frame type as Unsupported rather than fail on the
        // truncated header.
        let bytes: Vec<u8> = vec![
            0xFF, 0xD8, // SOI
            0xFF, 0xC5, // SOF5
            0x00, 0x08, // Lf
            0x08, // P
            0x00, 0x08, // Y
            0x00, 0x08, // X
            0x01, // Nf
            0xFF, 0xD9, // EOI
        ];
        let err = decode_packet(8, 8, bytes).expect_err("expected decode error");
        assert!(
            matches!(err, Error::Unsupported(_)),
            "expected Unsupported, got {err:?}"
        );
    }
}