use crate::error::{MjpegError as Error, Result};
#[cfg(feature = "registry")]
use oxideav_core::frame::VideoPlane;
#[cfg(feature = "registry")]
use oxideav_core::{PixelFormat, VideoFrame};
#[cfg(not(feature = "registry"))]
use crate::image::{
MjpegFrame as VideoFrame, MjpegPixelFormat as PixelFormat, MjpegPlane as VideoPlane,
};
#[cfg(feature = "registry")]
pub use crate::registry::make_decoder;
use crate::jpeg::arith::{
decode_ac as arith_decode_ac, decode_ac_refine as arith_decode_ac_refine,
decode_dc_diff as arith_decode_dc_diff, decode_fixed_bit as arith_decode_fixed_bit,
decode_lossless_diff as arith_decode_lossless_diff, AcRefineStats, AcStats, ArithDecoder,
DcStats, LosslessStats,
};
use crate::jpeg::dct::idct8x8;
use crate::jpeg::huffman::{parse_dht, HuffTable};
use crate::jpeg::markers::{self, *};
use crate::jpeg::parser::{
parse_dac, parse_dnl, parse_dri, parse_sof, parse_sos, MarkerWalker, SofInfo, SosInfo,
};
use crate::jpeg::quant::{parse_dqt, QuantTable};
use crate::jpeg::zigzag::ZIGZAG;
/// Decoder-wide state accumulated while walking the marker segments of one
/// JPEG image.
struct JpegState {
    // ---- Frame-level information --------------------------------------
    /// Parsed SOFn header, once one has been seen.
    sof: Option<SofInfo>,
    /// Restart interval in MCUs from the latest DRI; 0 disables restart
    /// handling in the scan decoders.
    restart_interval: u16,
    /// Colour-transform byte (offset 11) of an Adobe APP14 segment, if any.
    adobe_transform: Option<u8>,

    // ---- Tables, indexed by table selector 0..=3 -----------------------
    /// Quantisation tables from DQT.
    quant: [Option<QuantTable>; 4],
    /// DC Huffman tables from DHT.
    dc_huff: [Option<HuffTable>; 4],
    /// AC Huffman tables from DHT.
    ac_huff: [Option<HuffTable>; 4],
    /// DC arithmetic conditioning (L, U) from DAC.
    arith_dc: [Option<ArithDcConditioning>; 4],
    /// AC arithmetic conditioning (Kx) from DAC.
    arith_ac: [Option<ArithAcConditioning>; 4],

    // ---- Coding-process flags -----------------------------------------
    /// SOF2: progressive, Huffman-coded.
    progressive: bool,
    /// Sequential scans are being accumulated into coefficient planes
    /// because the direct fast path could not be used.
    seq_accum: bool,
    /// SOF3 / SOF11: lossless.
    lossless: bool,
    /// Lossless with arithmetic entropy coding (SOF11).
    lossless_arith: bool,
    /// SOF9: sequential, arithmetic-coded.
    arithmetic: bool,
    /// SOF10: progressive, arithmetic-coded.
    progressive_arith: bool,
}
/// DC conditioning bounds taken from a DAC entry with `Tc = 0`.
#[derive(Debug, Clone)]
struct ArithDcConditioning {
    /// Lower bound `L` (low nibble of the DAC `Cs` byte).
    pub l: u8,
    /// Upper bound `U` (high nibble of the DAC `Cs` byte).
    pub u: u8,
}
/// AC conditioning taken from a DAC entry with `Tc != 0`.
#[derive(Debug, Clone)]
struct ArithAcConditioning {
    /// The `Kx` value (the full DAC `Cs` byte).
    pub kx: u8,
}
impl JpegState {
fn new() -> Self {
Self {
quant: Default::default(),
dc_huff: Default::default(),
ac_huff: Default::default(),
restart_interval: 0,
sof: None,
progressive: false,
seq_accum: false,
lossless: false,
lossless_arith: false,
adobe_transform: None,
arithmetic: false,
progressive_arith: false,
arith_dc: Default::default(),
arith_ac: Default::default(),
}
}
}
/// Largest `width * height * component_count` a SOF header may declare
/// (2^26 = 64 Mi samples); larger frames are rejected as unsupported.
const MAX_PIXEL_BUDGET: u64 = 1 << 26;
/// Checks the frame-header invariants shared by every SOFn variant.
///
/// Requires 1..=4 components, each with sampling factors in 1..=4 and a
/// quantisation selector `Tq` of at most 3. The total sample count
/// (`width * height * Nf`, saturating) must not exceed [`MAX_PIXEL_BUDGET`].
///
/// # Errors
/// Invalid for `Nf = 0`, bad sampling factors or `Tq > 3`. Unsupported for
/// `Nf > 4` or an over-budget frame.
fn validate_sof(sof: &SofInfo) -> Result<()> {
    let nf = sof.components.len();
    if nf == 0 {
        return Err(Error::invalid("SOF: Nf = 0"));
    }
    if nf > 4 {
        return Err(Error::unsupported("SOF: Nf > 4"));
    }
    for comp in &sof.components {
        let sampling_ok =
            (1..=4).contains(&comp.h_factor) && (1..=4).contains(&comp.v_factor);
        if !sampling_ok {
            return Err(Error::invalid("SOF: Hi/Vi outside 1..=4"));
        }
        if comp.qt_id > 3 {
            return Err(Error::invalid("SOF: Tq > 3"));
        }
    }
    let total_samples = [sof.width as u64, sof.height as u64, nf as u64]
        .iter()
        .fold(1u64, |acc, &dim| acc.saturating_mul(dim));
    if total_samples > MAX_PIXEL_BUDGET {
        return Err(Error::unsupported("SOF: pixel budget exceeded"));
    }
    Ok(())
}
/// Applies the extra constraints the lossless (SOF3 / SOF11) decoders rely on.
///
/// These are:
/// * precision 2..=16;
/// * 1, 3 or 4 components;
/// * 4-component frames only at precision 8;
/// * `H = V = 1` on every component of a multi-component frame.
///
/// # Errors
/// Returns an unsupported error naming the first violated constraint.
fn validate_lossless_sof(sof: &SofInfo) -> Result<()> {
    let precision = sof.precision;
    let nf = sof.components.len();
    if !(2..=16).contains(&precision) {
        return Err(Error::unsupported(format!(
            "lossless JPEG: precision {} out of range 2..=16",
            precision
        )));
    }
    match nf {
        1 | 3 | 4 => {}
        _ => {
            return Err(Error::unsupported(format!(
                "lossless JPEG: {} component(s) — only 1 (grayscale), 3 (RGB-class) and 4 (CMYK-class) are supported",
                nf
            )))
        }
    }
    if nf == 4 && precision != 8 {
        return Err(Error::unsupported(format!(
            "lossless JPEG: 4-component scans require precision 8, got {}",
            precision
        )));
    }
    let subsampled = nf > 1
        && sof
            .components
            .iter()
            .any(|c| c.h_factor != 1 || c.v_factor != 1);
    if subsampled {
        return Err(Error::unsupported(
            "lossless JPEG: multi-component scans require H_i = V_i = 1",
        ));
    }
    Ok(())
}
fn validate_sos(sos: &SosInfo) -> Result<()> {
if sos.components.is_empty() {
return Err(Error::invalid("SOS: Ns = 0"));
}
if sos.components.len() > 4 {
return Err(Error::invalid("SOS: Ns > 4"));
}
for sc in &sos.components {
if sc.dc_table >= 4 {
return Err(Error::invalid("SOS: Tdj > 3"));
}
if sc.ac_table >= 4 {
return Err(Error::invalid("SOS: Taj > 3"));
}
}
Ok(())
}
/// Pre-pass that recovers a deferred image height from a DNL segment.
///
/// `data` is the stream following the SOI marker.
///
/// Returns `Ok(None)`, leaving the SOF height untouched, in any of these cases:
/// * the first frame header declares a non-zero height;
/// * the frame header's payload cannot be read or parsed, so the main pass
///   surfaces that error itself;
/// * EOI, SOS, end of data, or a hierarchical / SOF13..15 marker comes
///   before any SOF.
///
/// When the SOF declares `Y = 0`, this pass requires the marker immediately
/// after the first scan's entropy-coded data to be DNL, and returns its
/// line count as `Ok(Some(nl))`.
///
/// # Errors
/// Propagates marker-walker and DNL-parse errors. Returns an invalid-data
/// error when `Y = 0` and any of the following holds:
/// * no scan precedes EOI or the end of data;
/// * nothing follows the first scan;
/// * the marker after the first scan is not DNL.
fn resolve_dnl_height(data: &[u8]) -> Result<Option<u16>> {
    let mut walker = MarkerWalker::new(data);
    // Phase 1: locate the frame header and decide whether a DNL is needed.
    loop {
        let Some(marker) = walker.next_marker()? else {
            return Ok(None);
        };
        match marker {
            markers::SOI => continue,
            m if markers::is_rst(m) => continue,
            markers::EOI => return Ok(None),
            // SOF5..7 (hierarchical) and SOF13..15 are matched before the
            // generic `is_sof` guard below so they never reach it; the main
            // decode pass rejects these frame types.
            0xC5..=0xC7 | 0xCD..=0xCF => return Ok(None),
            m if markers::is_sof(m) => {
                // An unreadable/unparseable SOF is not this pass's concern.
                let Ok(p) = walker.read_segment_payload() else {
                    return Ok(None);
                };
                let Ok(sof) = parse_sof(p) else {
                    return Ok(None);
                };
                if sof.height != 0 {
                    return Ok(None);
                }
                // Y = 0: the height has to come from a DNL segment.
                break;
            }
            markers::SOS => {
                return Ok(None);
            }
            _ => {
                // Tables, APPn, COM, ...: skip the payload.
                let _ = walker.read_segment_payload()?;
            }
        }
    }
    // Phase 2: skip forward to the first scan and consume it entirely
    // (header plus entropy-coded data).
    loop {
        let Some(marker) = walker.next_marker()? else {
            return Err(Error::invalid(
                "JPEG: SOF Y = 0 but stream ended before the first scan",
            ));
        };
        match marker {
            markers::SOI => continue,
            m if markers::is_rst(m) => continue,
            markers::EOI => {
                return Err(Error::invalid(
                    "JPEG: SOF Y = 0 but no scan precedes EOI (DNL required)",
                ));
            }
            markers::SOS => {
                let _ = walker.read_segment_payload()?;
                let _ = walker.read_scan_data()?;
                break;
            }
            _ => {
                let _ = walker.read_segment_payload()?;
            }
        }
    }
    // Phase 3: the very next marker must be DNL; its NL field is the height.
    let Some(marker) = walker.next_marker()? else {
        return Err(Error::invalid(
            "JPEG: SOF Y = 0 but no DNL marker follows the first scan",
        ));
    };
    if marker != markers::DNL {
        return Err(Error::invalid(
            "JPEG: SOF Y = 0 but the marker after the first scan is not DNL",
        ));
    }
    let p = walker.read_segment_payload()?;
    let nl = parse_dnl(p)?;
    Ok(Some(nl))
}
/// Fills in the frame height from a DNL-supplied line count when the SOF
/// deferred it (`Y = 0`). A non-zero SOF height is never overridden, and a
/// `None` count leaves the header unchanged.
fn apply_dnl_height(sof: &mut SofInfo, dnl_height: Option<u16>) {
    if let (0, Some(lines)) = (sof.height, dnl_height) {
        sof.height = lines;
    }
}
pub(crate) fn decode_jpeg(data: &[u8], pts: Option<i64>) -> Result<VideoFrame> {
if data.len() < 2 || data[0] != 0xFF || data[1] != markers::SOI {
return Err(Error::invalid("JPEG: missing SOI"));
}
let dnl_height = resolve_dnl_height(&data[2..])?;
let mut walker = MarkerWalker::new(&data[2..]);
let mut state = JpegState::new();
let mut coef_buf: Vec<Vec<[i32; 64]>> = Vec::new();
loop {
let Some(marker) = walker.next_marker()? else {
return Err(Error::invalid("JPEG: unexpected EOF before EOI"));
};
match marker {
EOI => {
if state.progressive
|| state.seq_accum
|| state.arithmetic
|| state.progressive_arith
{
return render_from_coefs(&state, &coef_buf, pts);
}
return Err(Error::invalid("JPEG: EOI before SOS"));
}
SOI => continue,
m if markers::is_rst(m) => continue,
DQT => {
let p = walker.read_segment_payload()?;
parse_dqt(p, &mut state.quant)?;
}
DHT => {
let p = walker.read_segment_payload()?;
parse_dht(p, &mut state.dc_huff, &mut state.ac_huff)?;
}
DAC => {
let p = walker.read_segment_payload()?;
let entries = parse_dac(p)?;
for e in entries {
if e.tc == 0 {
let l = e.cs & 0x0F;
let u = e.cs >> 4;
if l > u || u > 15 {
return Err(Error::invalid("DAC: invalid L/U bounds"));
}
state.arith_dc[e.tb as usize] = Some(ArithDcConditioning { l, u });
} else {
state.arith_ac[e.tb as usize] = Some(ArithAcConditioning { kx: e.cs });
}
}
}
DRI => {
let p = walker.read_segment_payload()?;
state.restart_interval = parse_dri(p)?;
}
SOF0 | SOF1 => {
if state.sof.is_some() {
return Err(Error::invalid("JPEG: multiple SOF segments"));
}
let p = walker.read_segment_payload()?;
let mut sof = parse_sof(p)?;
apply_dnl_height(&mut sof, dnl_height);
validate_sof(&sof)?;
state.sof = Some(sof);
}
SOF2 => {
if state.sof.is_some() {
return Err(Error::invalid("JPEG: multiple SOF segments"));
}
let p = walker.read_segment_payload()?;
let mut sof = parse_sof(p)?;
apply_dnl_height(&mut sof, dnl_height);
validate_sof(&sof)?;
if sof.precision != 8 && sof.precision != 12 {
return Err(Error::unsupported(format!(
"progressive JPEG: precision {} (only 8 and 12 are supported)",
sof.precision
)));
}
if sof.components.len() == 4 && sof.precision != 8 {
return Err(Error::unsupported(
"progressive JPEG: 4-component scans only at P = 8",
));
}
coef_buf = init_coef_buffers(&sof)?;
state.sof = Some(sof);
state.progressive = true;
}
SOF3 | markers::SOF11 => {
if state.sof.is_some() {
return Err(Error::invalid("JPEG: multiple SOF segments"));
}
let p = walker.read_segment_payload()?;
let mut sof = parse_sof(p)?;
apply_dnl_height(&mut sof, dnl_height);
validate_sof(&sof)?;
validate_lossless_sof(&sof)?;
state.sof = Some(sof);
state.lossless = true;
state.lossless_arith = marker == markers::SOF11;
}
SOF9 => {
if state.sof.is_some() {
return Err(Error::invalid("JPEG: multiple SOF segments"));
}
let p = walker.read_segment_payload()?;
let mut sof = parse_sof(p)?;
apply_dnl_height(&mut sof, dnl_height);
validate_sof(&sof)?;
if sof.precision != 8 {
return Err(Error::unsupported(format!(
"arithmetic JPEG: precision {} (only 8 is supported)",
sof.precision
)));
}
if sof.components.len() > 3 {
return Err(Error::unsupported(
"arithmetic JPEG: 4+ component scans not supported",
));
}
coef_buf = init_coef_buffers(&sof)?;
state.sof = Some(sof);
state.arithmetic = true;
}
markers::SOF10 => {
if state.sof.is_some() {
return Err(Error::invalid("JPEG: multiple SOF segments"));
}
let p = walker.read_segment_payload()?;
let mut sof = parse_sof(p)?;
apply_dnl_height(&mut sof, dnl_height);
validate_sof(&sof)?;
if sof.precision != 8 && sof.precision != 12 {
return Err(Error::unsupported(format!(
"progressive arithmetic JPEG: precision {} (only 8 and 12 are supported)",
sof.precision
)));
}
if sof.components.len() == 4 && sof.precision != 8 {
return Err(Error::unsupported(
"progressive arithmetic JPEG: 4-component scans only at P = 8",
));
}
coef_buf = init_coef_buffers(&sof)?;
state.sof = Some(sof);
state.progressive_arith = true;
}
0xC5..=0xC7 | 0xCD..=0xCF => {
let _ = walker.read_segment_payload();
return Err(Error::unsupported(
"JPEG: hierarchical and SOF13..15 arithmetic variants are not supported",
));
}
SOS => {
let p = walker.read_segment_payload()?;
let sos = parse_sos(p)?;
validate_sos(&sos)?;
let scan = walker.read_scan_data()?;
if state.lossless {
if state.lossless_arith {
return decode_lossless_arith_scan(&state, &sos, scan, pts);
}
return decode_lossless_scan(&state, &sos, scan, pts);
}
if state.progressive_arith {
decode_progressive_arith_scan(&state, &sos, scan, &mut coef_buf)?;
} else if state.arithmetic {
decode_arith_scan(&state, &sos, scan, &mut coef_buf)?;
} else if state.progressive {
decode_progressive_scan(&state, &sos, scan, &mut coef_buf)?;
} else {
let sof = state
.sof
.as_ref()
.ok_or_else(|| Error::invalid("SOS before SOF"))?;
let fully_interleaved = sos.components.len() == sof.components.len();
let fast_path_ok = fully_interleaved
&& !state.seq_accum
&& sof.components.len() <= 3
&& sof.precision == 8;
if fast_path_ok {
return decode_scan(&state, &sos, scan, pts);
}
if !state.seq_accum {
coef_buf = init_coef_buffers(sof)?;
state.seq_accum = true;
}
decode_sequential_scan_accum(&state, &sos, scan, &mut coef_buf)?;
}
}
COM => {
let _ = walker.read_segment_payload()?;
}
markers::DNL => {
let _ = walker.read_segment_payload()?;
}
markers::APP14 => {
let p = walker.read_segment_payload()?;
if p.len() >= 12 && &p[0..5] == b"Adobe" {
state.adobe_transform = Some(p[11]);
}
}
m if markers::is_app(m) => {
let _ = walker.read_segment_payload()?;
}
_ => {
let _ = walker.read_segment_payload();
}
}
}
}
fn init_coef_buffers(sof: &SofInfo) -> Result<Vec<Vec<[i32; 64]>>> {
if sof.precision != 8 && sof.precision != 12 {
return Err(Error::unsupported(format!(
"coef accumulator: precision {} (only 8 and 12 are supported)",
sof.precision
)));
}
if sof.components.is_empty() {
return Err(Error::invalid("SOF: no components"));
}
if sof.components.len() > 4 {
return Err(Error::unsupported("coef accumulator: >4 components"));
}
let h_max = sof.components.iter().map(|c| c.h_factor).max().unwrap_or(1);
let v_max = sof.components.iter().map(|c| c.v_factor).max().unwrap_or(1);
if h_max == 0 || v_max == 0 {
return Err(Error::invalid("SOF: sampling factor = 0"));
}
let width = sof.width as usize;
let height = sof.height as usize;
let mcu_w_px = 8 * h_max as usize;
let mcu_h_px = 8 * v_max as usize;
let mcus_x = width.div_ceil(mcu_w_px);
let mcus_y = height.div_ceil(mcu_h_px);
let mut out = Vec::with_capacity(sof.components.len());
for c in &sof.components {
let blocks_x = mcus_x * c.h_factor as usize;
let blocks_y = mcus_y * c.v_factor as usize;
out.push(vec![[0i32; 64]; blocks_x * blocks_y]);
}
Ok(out)
}
/// MSB-first bit reader over Huffman-coded entropy data. It undoes `FF 00`
/// byte stuffing and stops at markers.
struct BitReader<'a> {
    /// Entropy-coded scan bytes.
    buf: &'a [u8],
    /// Index of the next unread byte of `buf`.
    pos: usize,
    /// Left-aligned bit buffer; the next bit handed out is bit 31.
    bits: u32,
    /// Number of valid bits currently held in `bits`.
    nbits: u32,
    /// RSTn marker code consumed by a byte fetch, until `reset_at_restart`
    /// clears it.
    pub saw_rst: Option<u8>,
}
impl<'a> BitReader<'a> {
fn new(buf: &'a [u8]) -> Self {
Self {
buf,
pos: 0,
bits: 0,
nbits: 0,
saw_rst: None,
}
}
fn next_byte_with_stuff(&mut self) -> Result<Option<u8>> {
if self.pos >= self.buf.len() {
return Ok(None);
}
let b = self.buf[self.pos];
self.pos += 1;
if b == 0xFF {
while self.pos < self.buf.len() && self.buf[self.pos] == 0xFF {
self.pos += 1;
}
if self.pos >= self.buf.len() {
return Err(Error::invalid("scan: 0xFF at end without followup"));
}
let next = self.buf[self.pos];
self.pos += 1;
if next == 0x00 {
return Ok(Some(0xFF));
}
if markers::is_rst(next) {
self.saw_rst = Some(next);
return Ok(None);
}
self.pos -= 2;
return Ok(None);
}
Ok(Some(b))
}
fn fill(&mut self, needed: u32) -> Result<()> {
if needed > 24 {
return Err(Error::invalid("BitReader: requested > 24 bits"));
}
while self.nbits < needed {
match self.next_byte_with_stuff()? {
Some(b) => {
self.bits |= (b as u32) << (24 - self.nbits);
self.nbits += 8;
}
None => {
self.bits |= 0;
self.nbits = needed;
break;
}
}
}
Ok(())
}
fn get_bits(&mut self, n: u32) -> Result<u32> {
if n == 0 {
return Ok(0);
}
if n > 24 {
return Err(Error::invalid("BitReader: get_bits(n > 24)"));
}
self.fill(n)?;
let v = self.bits >> (32 - n);
self.bits <<= n;
self.nbits -= n;
Ok(v)
}
fn reset_at_restart(&mut self) {
self.bits = 0;
self.nbits = 0;
self.saw_rst = None;
}
}
fn decode_huff(br: &mut BitReader<'_>, t: &HuffTable) -> Result<u8> {
let mut code: i32 = 0;
for l in 0..16 {
let bit = br.get_bits(1)? as i32;
code = (code << 1) | bit;
if code <= t.max_code[l] {
let idx = (t.val_offset[l] + code) as usize;
if idx >= t.values.len() {
return Err(Error::invalid("huffman: value index OOB"));
}
return Ok(t.values[idx]);
}
}
Err(Error::invalid("huffman: no matching code (length > 16)"))
}
/// Converts a `size`-bit magnitude category value into a signed coefficient
/// (the EXTEND procedure of ITU-T T.81 F.2.2.1).
///
/// Values whose top bit is clear encode negatives:
/// `value - (2^size - 1)`. `size = 0` always yields 0.
fn extend(value: i32, size: u32) -> i32 {
    match size {
        0 => 0,
        _ if value < (1 << (size - 1)) => value - ((1 << size) - 1),
        _ => value,
    }
}
/// Fast path for a baseline 8-bit Huffman scan that carries every frame
/// component (one to three).
///
/// Each block is decoded, dequantised and inverse-transformed straight into
/// per-component sample planes. The planes are then cropped into the output
/// frame, whose format is:
/// * `Gray8` for one component;
/// * `Rgb24` for RGB-tagged three-component frames (all `H = V = 1`);
/// * otherwise `Yuv444P`, `Yuv422P`, `Yuv420P` or `Yuv411P`, chosen from the
///   luma sampling factors (chroma must be 1x1).
///
/// A one-component scan is non-interleaved (ITU-T T.81 A.2.2). Each of its
/// MCUs is therefore a single 8x8 block visited in raster order, whatever
/// `H/V` the SOF declares for that component.
///
/// # Errors
/// * Unsupported: precision other than 8; two or more than three components;
///   a scan that omits components; sampling layouts outside the list above.
/// * Invalid: missing tables, unknown component ids, malformed entropy data.
fn decode_scan(
    state: &JpegState,
    sos: &SosInfo,
    scan: &[u8],
    pts: Option<i64>,
) -> Result<VideoFrame> {
    let sof = state
        .sof
        .as_ref()
        .ok_or_else(|| Error::invalid("SOS before SOF"))?;
    if sof.precision != 8 {
        return Err(Error::unsupported("precision != 8"));
    }
    if sof.components.is_empty() {
        return Err(Error::invalid("SOF: no components"));
    }
    if sof.components.len() > 3 {
        return Err(Error::unsupported("4+ components"));
    }
    if sos.components.len() != sof.components.len() {
        return Err(Error::unsupported("non-interleaved scan"));
    }
    let n_comp = sof.components.len();
    let grayscale = n_comp == 1;
    let is_rgb = detect_rgb_3comp(sof, state.adobe_transform);
    let raw_h_max = sof.components.iter().map(|c| c.h_factor).max().unwrap_or(1);
    let raw_v_max = sof.components.iter().map(|c| c.v_factor).max().unwrap_or(1);
    if raw_h_max == 0 || raw_v_max == 0 {
        return Err(Error::invalid("SOF: sampling factor = 0"));
    }
    // Blocks per MCU (horizontal, vertical) for each component. A single-
    // component scan is non-interleaved, so its MCU is exactly one block.
    let factors: Vec<(usize, usize)> = sof
        .components
        .iter()
        .map(|c| {
            if grayscale {
                (1, 1)
            } else {
                (c.h_factor as usize, c.v_factor as usize)
            }
        })
        .collect();
    let h_max = factors.iter().map(|&(h, _)| h).max().unwrap_or(1);
    let v_max = factors.iter().map(|&(_, v)| v).max().unwrap_or(1);
    let pix_fmt = if grayscale {
        PixelFormat::Gray8
    } else if is_rgb {
        for c in &sof.components {
            if c.h_factor != 1 || c.v_factor != 1 {
                return Err(Error::unsupported(
                    "RGB baseline JPEG: every component must declare H = V = 1",
                ));
            }
        }
        PixelFormat::Rgb24
    } else if n_comp == 3 {
        let y = sof.components[0];
        let cb = sof.components[1];
        let cr = sof.components[2];
        if cb.h_factor != cr.h_factor || cb.v_factor != cr.v_factor {
            return Err(Error::unsupported(
                "chroma components have different sampling factors",
            ));
        }
        if cb.h_factor != 1 || cb.v_factor != 1 {
            return Err(Error::unsupported(
                "chroma components must have factor 1 (luma carries the oversampling)",
            ));
        }
        match (y.h_factor, y.v_factor) {
            (1, 1) => PixelFormat::Yuv444P,
            (2, 1) => PixelFormat::Yuv422P,
            (2, 2) => PixelFormat::Yuv420P,
            (4, 1) => PixelFormat::Yuv411P,
            _ => {
                return Err(Error::unsupported(format!(
                    "luma sampling {}x{}",
                    y.h_factor, y.v_factor
                )))
            }
        }
    } else {
        return Err(Error::unsupported("2-component JPEG"));
    };
    let width = sof.width as usize;
    let height = sof.height as usize;
    let mcus_x = width.div_ceil(8 * h_max);
    let mcus_y = height.div_ceil(8 * v_max);
    // One MCU-padded sample plane per component; stride == padded width.
    let mut comp_buf: Vec<Vec<u8>> = Vec::with_capacity(n_comp);
    let mut comp_stride: Vec<usize> = Vec::with_capacity(n_comp);
    for &(h_f, v_f) in &factors {
        let w_full = mcus_x * 8 * h_f;
        let h_full = mcus_y * 8 * v_f;
        comp_buf.push(vec![0u8; w_full * h_full]);
        comp_stride.push(w_full);
    }
    let sos_map: Vec<usize> = sos
        .components
        .iter()
        .map(|sc| {
            sof.components
                .iter()
                .position(|fc| fc.id == sc.id)
                .ok_or_else(|| Error::invalid("SOS: component id not in SOF"))
        })
        .collect::<Result<Vec<_>>>()?;
    let dc_tables: Vec<&HuffTable> = sos
        .components
        .iter()
        .map(|sc| {
            state.dc_huff[sc.dc_table as usize]
                .as_ref()
                .ok_or_else(|| Error::invalid("SOS: DC Huffman table missing"))
        })
        .collect::<Result<Vec<_>>>()?;
    let ac_tables: Vec<&HuffTable> = sos
        .components
        .iter()
        .map(|sc| {
            state.ac_huff[sc.ac_table as usize]
                .as_ref()
                .ok_or_else(|| Error::invalid("SOS: AC Huffman table missing"))
        })
        .collect::<Result<Vec<_>>>()?;
    let quant_tables: Vec<&QuantTable> = sof
        .components
        .iter()
        .map(|c| {
            state.quant[c.qt_id as usize]
                .as_ref()
                .ok_or_else(|| Error::invalid("quant table missing for component"))
        })
        .collect::<Result<Vec<_>>>()?;
    let mut br = BitReader::new(scan);
    let mut prev_dc = vec![0i32; n_comp];
    let mut mcus_since_restart: u32 = 0;
    for my in 0..mcus_y {
        for mx in 0..mcus_x {
            if state.restart_interval != 0
                && mcus_since_restart != 0
                && mcus_since_restart % state.restart_interval as u32 == 0
            {
                // Drop the finished interval's padding bits and scan forward
                // to its RSTn marker. The marker number is not checked
                // against the expected RST0..RST7 sequence.
                br.bits = 0;
                br.nbits = 0;
                while br.saw_rst.is_none() {
                    if br.next_byte_with_stuff()?.is_none() {
                        break;
                    }
                }
                if br.saw_rst.is_some() {
                    prev_dc.fill(0);
                    br.reset_at_restart();
                }
            }
            for (sidx, &sof_idx) in sos_map.iter().enumerate() {
                let (h_f, v_f) = factors[sof_idx];
                let qt = quant_tables[sof_idx];
                let stride = comp_stride[sof_idx];
                for by in 0..v_f {
                    for bx in 0..h_f {
                        let mut block = [0i32; 64];
                        decode_block(
                            &mut br,
                            dc_tables[sidx],
                            ac_tables[sidx],
                            &mut prev_dc[sof_idx],
                            &mut block,
                        )?;
                        let mut fblock = [0.0f32; 64];
                        for k in 0..64 {
                            fblock[k] = block[k] as f32 * qt.values[k] as f32;
                        }
                        idct8x8(&mut fblock);
                        let dst_x0 = mx * 8 * h_f + bx * 8;
                        let dst_y0 = my * 8 * v_f + by * 8;
                        let buf = &mut comp_buf[sof_idx];
                        for j in 0..8 {
                            for i in 0..8 {
                                // Level shift back to unsigned, then clamp.
                                let v = fblock[j * 8 + i] + 128.0;
                                let px = if v <= 0.0 {
                                    0
                                } else if v >= 255.0 {
                                    255
                                } else {
                                    v.round() as u8
                                };
                                buf[(dst_y0 + j) * stride + dst_x0 + i] = px;
                            }
                        }
                    }
                }
            }
            mcus_since_restart += 1;
        }
    }
    let out_format = pix_fmt;
    let mut planes: Vec<VideoPlane> = Vec::new();
    match out_format {
        PixelFormat::Gray8 => {
            let stride = width;
            let mut data = vec![0u8; stride * height];
            let src_stride = comp_stride[0];
            for y in 0..height {
                data[y * stride..y * stride + width]
                    .copy_from_slice(&comp_buf[0][y * src_stride..y * src_stride + width]);
            }
            planes.push(VideoPlane { stride, data });
        }
        PixelFormat::Rgb24 => {
            let stride = width * 3;
            let mut data = vec![0u8; stride * height];
            let src_strides = [comp_stride[0], comp_stride[1], comp_stride[2]];
            for y in 0..height {
                let off = y * stride;
                for x in 0..width {
                    data[off + x * 3] = comp_buf[0][y * src_strides[0] + x];
                    data[off + x * 3 + 1] = comp_buf[1][y * src_strides[1] + x];
                    data[off + x * 3 + 2] = comp_buf[2][y * src_strides[2] + x];
                }
            }
            planes.push(VideoPlane { stride, data });
        }
        PixelFormat::Yuv444P
        | PixelFormat::Yuv422P
        | PixelFormat::Yuv420P
        | PixelFormat::Yuv411P => {
            let (c_w, c_h) = match out_format {
                PixelFormat::Yuv444P => (width, height),
                PixelFormat::Yuv422P => (width.div_ceil(2), height),
                PixelFormat::Yuv420P => (width.div_ceil(2), height.div_ceil(2)),
                PixelFormat::Yuv411P => (width.div_ceil(4), height),
                _ => unreachable!(),
            };
            let y_stride = width;
            let mut y_data = vec![0u8; y_stride * height];
            let src_stride_y = comp_stride[0];
            for y in 0..height {
                y_data[y * y_stride..y * y_stride + width]
                    .copy_from_slice(&comp_buf[0][y * src_stride_y..y * src_stride_y + width]);
            }
            planes.push(VideoPlane {
                stride: y_stride,
                data: y_data,
            });
            for ci in [1usize, 2] {
                let src_stride = comp_stride[ci];
                let stride = c_w;
                let mut data = vec![0u8; stride * c_h];
                for y in 0..c_h {
                    data[y * stride..y * stride + c_w]
                        .copy_from_slice(&comp_buf[ci][y * src_stride..y * src_stride + c_w]);
                }
                planes.push(VideoPlane { stride, data });
            }
        }
        _ => unreachable!(),
    }
    Ok(VideoFrame { pts, planes })
}
/// Reports whether a three-component frame holds RGB rather than YCbCr.
///
/// That is the case when an Adobe APP14 transform byte of 0 was seen, or
/// when the component ids are literally `'R'`, `'G'`, `'B'`. Frames with any
/// other component count are never RGB.
fn detect_rgb_3comp(sof: &SofInfo, adobe_transform: Option<u8>) -> bool {
    match &sof.components[..] {
        [r, g, b] => adobe_transform == Some(0) || [r.id, g.id, b.id] == *b"RGB",
        _ => false,
    }
}
/// Decodes one Huffman-coded 8x8 block into `out_natural` (natural order).
///
/// The DC difference is added to `*prev_dc`, which then becomes
/// coefficient 0. AC run/size symbols place coefficients through the
/// zig-zag table until EOB or position 63. `0xF0` (ZRL) skips 16 positions.
/// Only non-zero AC positions are written, so the caller supplies a zeroed
/// block.
///
/// # Errors
/// Fails on Huffman or bit-read errors, or when a run overshoots the block.
fn decode_block(
    br: &mut BitReader<'_>,
    dc: &HuffTable,
    ac: &HuffTable,
    prev_dc: &mut i32,
    out_natural: &mut [i32; 64],
) -> Result<()> {
    let dc_size = u32::from(decode_huff(br, dc)?);
    let dc_diff = match dc_size {
        0 => 0,
        size => extend(br.get_bits(size)? as i32, size),
    };
    *prev_dc = prev_dc.wrapping_add(dc_diff);
    out_natural[0] = *prev_dc;

    let mut zz = 1usize;
    while zz < 64 {
        let symbol = decode_huff(br, ac)?;
        let run = (symbol >> 4) as usize;
        let size = u32::from(symbol & 0x0F);
        match (run, size) {
            (15, 0) => zz += 16,
            (_, 0) => break,
            _ => {
                zz += run;
                if zz >= 64 {
                    return Err(Error::invalid("JPEG AC: run out of block"));
                }
                out_natural[ZIGZAG[zz]] = extend(br.get_bits(size)? as i32, size);
                zz += 1;
            }
        }
    }
    Ok(())
}
/// Decodes one sequential Huffman scan (8- or 12-bit) into `coefs`, the
/// MCU-padded, row-major coefficient planes from `init_coef_buffers`, for
/// rendering once the image is complete.
///
/// Both scan layouts of ITU-T T.81 A.2 are handled:
/// * `Ns > 1` (interleaved, possibly a subset of the frame's components):
///   each MCU carries `Hi x Vi` blocks of every scan component over the
///   frame-wide MCU grid;
/// * `Ns == 1` (non-interleaved): each MCU is one block, and the scan covers
///   only that component's own block grid, `ceil(X*Hi / (8*Hmax))` by
///   `ceil(Y*Vi / (8*Vmax))`. This can be smaller than the MCU-padded plane.
///
/// DC predictors are kept per scan component and reset after each RSTn.
///
/// # Errors
/// Unsupported for precisions other than 8/12. Invalid for SOS before SOF,
/// zero sampling factors, unknown component ids, missing tables or
/// malformed entropy data.
fn decode_sequential_scan_accum(
    state: &JpegState,
    sos: &SosInfo,
    scan: &[u8],
    coefs: &mut [Vec<[i32; 64]>],
) -> Result<()> {
    let sof = state
        .sof
        .as_ref()
        .ok_or_else(|| Error::invalid("SOS before SOF"))?;
    if sof.precision != 8 && sof.precision != 12 {
        return Err(Error::unsupported(format!(
            "scan: unsupported precision {}",
            sof.precision
        )));
    }
    let h_max = sof.components.iter().map(|c| c.h_factor).max().unwrap_or(1) as usize;
    let v_max = sof.components.iter().map(|c| c.v_factor).max().unwrap_or(1) as usize;
    if h_max == 0 || v_max == 0 {
        return Err(Error::invalid("SOF: sampling factor = 0"));
    }
    let width = sof.width as usize;
    let height = sof.height as usize;
    let mcus_x = width.div_ceil(8 * h_max);
    let mcus_y = height.div_ceil(8 * v_max);
    let sos_map: Vec<usize> = sos
        .components
        .iter()
        .map(|sc| {
            sof.components
                .iter()
                .position(|fc| fc.id == sc.id)
                .ok_or_else(|| Error::invalid("SOS: component id not in SOF"))
        })
        .collect::<Result<Vec<_>>>()?;
    let dc_tables: Vec<&HuffTable> = sos
        .components
        .iter()
        .map(|sc| {
            state.dc_huff[sc.dc_table as usize]
                .as_ref()
                .ok_or_else(|| Error::invalid("SOS: DC Huffman table missing"))
        })
        .collect::<Result<Vec<_>>>()?;
    let ac_tables: Vec<&HuffTable> = sos
        .components
        .iter()
        .map(|sc| {
            state.ac_huff[sc.ac_table as usize]
                .as_ref()
                .ok_or_else(|| Error::invalid("SOS: AC Huffman table missing"))
        })
        .collect::<Result<Vec<_>>>()?;
    // Interleaving depends on Ns alone: a scan with two of three components
    // is still interleaved, and a one-component scan never is.
    let interleaved = sos.components.len() > 1;
    let (scan_mcus_x, scan_mcus_y) = if interleaved {
        (mcus_x, mcus_y)
    } else {
        let c = sof.components[sos_map[0]];
        (
            (width * c.h_factor as usize).div_ceil(8 * h_max),
            (height * c.v_factor as usize).div_ceil(8 * v_max),
        )
    };
    let mut br = BitReader::new(scan);
    let mut prev_dc = vec![0i32; sos.components.len()];
    let mut mcus_since_restart: u32 = 0;
    for my in 0..scan_mcus_y {
        for mx in 0..scan_mcus_x {
            if state.restart_interval != 0
                && mcus_since_restart != 0
                && mcus_since_restart % state.restart_interval as u32 == 0
            {
                // Drop the finished interval's padding bits and scan forward
                // to its RSTn marker (the marker number is not validated).
                br.bits = 0;
                br.nbits = 0;
                while br.saw_rst.is_none() {
                    if br.next_byte_with_stuff()?.is_none() {
                        break;
                    }
                }
                if br.saw_rst.is_some() {
                    prev_dc.fill(0);
                    br.reset_at_restart();
                }
            }
            if interleaved {
                for (sidx, &sof_idx) in sos_map.iter().enumerate() {
                    let c = sof.components[sof_idx];
                    let blocks_x = mcus_x * c.h_factor as usize;
                    for by in 0..c.v_factor as usize {
                        for bx in 0..c.h_factor as usize {
                            let bidx_x = mx * c.h_factor as usize + bx;
                            let bidx_y = my * c.v_factor as usize + by;
                            let bi = bidx_y * blocks_x + bidx_x;
                            decode_block(
                                &mut br,
                                dc_tables[sidx],
                                ac_tables[sidx],
                                &mut prev_dc[sidx],
                                &mut coefs[sof_idx][bi],
                            )?;
                        }
                    }
                }
            } else {
                let sof_idx = sos_map[0];
                let c = sof.components[sof_idx];
                // Row stride of the MCU-padded plane, not of the scan grid.
                let blocks_x = mcus_x * c.h_factor as usize;
                let bi = my * blocks_x + mx;
                decode_block(
                    &mut br,
                    dc_tables[0],
                    ac_tables[0],
                    &mut prev_dc[0],
                    &mut coefs[sof_idx][bi],
                )?;
            }
            mcus_since_restart += 1;
        }
    }
    Ok(())
}
/// Decodes one sequential arithmetic-coded (SOF9) scan at 8-bit precision
/// into the MCU-padded coefficient planes `coefs`.
///
/// The layout follows ITU-T T.81 A.2:
/// * `Ns > 1`: interleaved over the frame MCU grid;
/// * `Ns == 1`: one block per MCU over the component's own block grid,
///   `ceil(X*Hi / (8*Hmax))` by `ceil(Y*Vi / (8*Vmax))`.
///
/// At each restart boundary the statistics are reset and a new decoder
/// starts just past the next marker found in the scan bytes.
///
/// # Errors
/// Unsupported for precisions other than 8. Invalid for SOS before SOF,
/// zero sampling factors, unknown component ids, a missing mid-scan restart
/// marker, or errors from the arithmetic decoder.
fn decode_arith_scan(
    state: &JpegState,
    sos: &SosInfo,
    scan: &[u8],
    coefs: &mut [Vec<[i32; 64]>],
) -> Result<()> {
    let sof = state
        .sof
        .as_ref()
        .ok_or_else(|| Error::invalid("SOS before SOF"))?;
    if sof.precision != 8 {
        return Err(Error::unsupported(format!(
            "arithmetic scan: precision {} (only 8 supported)",
            sof.precision
        )));
    }
    let h_max = sof.components.iter().map(|c| c.h_factor).max().unwrap_or(1) as usize;
    let v_max = sof.components.iter().map(|c| c.v_factor).max().unwrap_or(1) as usize;
    if h_max == 0 || v_max == 0 {
        return Err(Error::invalid("SOF: sampling factor = 0"));
    }
    let width = sof.width as usize;
    let height = sof.height as usize;
    let mcus_x = width.div_ceil(8 * h_max);
    let mcus_y = height.div_ceil(8 * v_max);
    let sos_map: Vec<usize> = sos
        .components
        .iter()
        .map(|sc| {
            sof.components
                .iter()
                .position(|fc| fc.id == sc.id)
                .ok_or_else(|| Error::invalid("SOS: component id not in SOF"))
        })
        .collect::<Result<Vec<_>>>()?;
    // Interleaving depends on Ns alone (T.81 A.2), not on Ns == Nf.
    let interleaved = sos.components.len() > 1;
    let n_comp = sof.components.len();
    // Statistics are kept per frame component. Each one takes the DAC
    // conditioning of the tables selected by its scan component.
    // NOTE(review): libjpeg's jdarith.c indexes its statistics bins by table
    // number, so components selecting the same table share bins. Confirm
    // that the per-component layout used here matches streams where several
    // components share a table.
    let mut dc_stats: Vec<DcStats> = (0..n_comp).map(|_| DcStats::new()).collect();
    let mut ac_stats: Vec<AcStats> = (0..n_comp).map(|_| AcStats::new()).collect();
    for (sc, &sof_idx) in sos.components.iter().zip(sos_map.iter()) {
        if let Some(cond) = state.arith_dc[sc.dc_table as usize].as_ref() {
            dc_stats[sof_idx].l = cond.l;
            dc_stats[sof_idx].u = cond.u;
        }
        if let Some(cond) = state.arith_ac[sc.ac_table as usize].as_ref() {
            ac_stats[sof_idx].kx = cond.kx;
        }
    }
    let mut scan_pos = 0usize;
    let mut decoder = ArithDecoder::new(scan);
    let mut mcus_since_restart: u32 = 0;
    let (scan_mcus_x, scan_mcus_y) = if interleaved {
        (mcus_x, mcus_y)
    } else {
        let c = sof.components[sos_map[0]];
        (
            (width * c.h_factor as usize).div_ceil(8 * h_max),
            (height * c.v_factor as usize).div_ceil(8 * v_max),
        )
    };
    for my in 0..scan_mcus_y {
        for mx in 0..scan_mcus_x {
            if state.restart_interval != 0
                && mcus_since_restart != 0
                && mcus_since_restart % state.restart_interval as u32 == 0
            {
                scan_pos = locate_next_marker_after(scan, scan_pos);
                if scan_pos >= scan.len() {
                    return Err(Error::invalid(
                        "arithmetic scan: missing restart marker mid-scan",
                    ));
                }
                for s in dc_stats.iter_mut() {
                    s.restart_reset();
                }
                for s in ac_stats.iter_mut() {
                    s.restart_reset();
                }
                decoder = ArithDecoder::new(&scan[scan_pos..]);
            }
            if interleaved {
                for &sof_idx in sos_map.iter() {
                    let c = sof.components[sof_idx];
                    let blocks_x = mcus_x * c.h_factor as usize;
                    for by in 0..c.v_factor as usize {
                        for bx in 0..c.h_factor as usize {
                            let bidx_x = mx * c.h_factor as usize + bx;
                            let bidx_y = my * c.v_factor as usize + by;
                            let bi = bidx_y * blocks_x + bidx_x;
                            decode_arith_block(
                                &mut decoder,
                                &mut dc_stats[sof_idx],
                                &mut ac_stats[sof_idx],
                                &mut coefs[sof_idx][bi],
                            )?;
                        }
                    }
                }
            } else {
                let sof_idx = sos_map[0];
                let c = sof.components[sof_idx];
                // Row stride of the MCU-padded plane, not of the scan grid.
                let blocks_x = mcus_x * c.h_factor as usize;
                let bi = my * blocks_x + mx;
                decode_arith_block(
                    &mut decoder,
                    &mut dc_stats[sof_idx],
                    &mut ac_stats[sof_idx],
                    &mut coefs[sof_idx][bi],
                )?;
            }
            mcus_since_restart += 1;
        }
    }
    Ok(())
}
/// Decodes one sequential arithmetic-coded block.
///
/// `block` is zeroed first. The decoded DC difference is added to the
/// predictor `dc.pred`, which becomes coefficient 0. AC coefficients 1..=63
/// are then decoded in place.
fn decode_arith_block(
    d: &mut ArithDecoder<'_>,
    dc: &mut DcStats,
    ac: &mut AcStats,
    block: &mut [i32; 64],
) -> Result<()> {
    *block = [0; 64];
    let dc_diff = arith_decode_dc_diff(d, dc)?;
    dc.pred = dc.pred.wrapping_add(dc_diff);
    block[0] = dc.pred;
    arith_decode_ac(d, ac, block, 1, 63)?;
    Ok(())
}
/// Returns the index just past the marker code of the first marker at or
/// after `from` in `scan`, or `scan.len()` if there is no complete marker.
///
/// A marker is a `0xFF` followed by anything other than `0x00`; a run of
/// `0xFF` fill bytes before the code is skipped. `scan.len()` is also
/// returned when `from` lies past the end.
fn locate_next_marker_after(scan: &[u8], from: usize) -> usize {
    let end = scan.len();
    let Some(tail) = scan.get(from..) else {
        return end;
    };
    let Some(rel) = tail
        .windows(2)
        .position(|pair| pair[0] == 0xFF && pair[1] != 0x00)
    else {
        return end;
    };
    // Step over the 0xFF prefix and any further fill bytes to the code byte.
    let after_prefix = from + rel + 1;
    scan[after_prefix..]
        .iter()
        .position(|&b| b != 0xFF)
        .map_or(end, |skip| after_prefix + skip + 1)
}
/// Decodes one progressive arithmetic-coded (SOF10) scan into the
/// MCU-padded coefficient planes `coefs`.
///
/// The scan kind is chosen by the spectral selection and approximation fields:
/// * DC scans (`Ss = Se = 0`) may be interleaved. The first pass (`Ah = 0`)
///   stores predicted DC values scaled by `2^Al`; a refinement pass ORs in
///   one bit at position `Al`.
/// * AC scans (`1 <= Ss <= Se <= 63`) must be single-component. The first
///   pass stores coefficients scaled by `2^Al`; a refinement pass updates them
///   in place.
///
/// Non-interleaved scans cover only the component's own block grid,
/// `ceil(X*Hi / (8*Hmax))` by `ceil(Y*Vi / (8*Vmax))` (ITU-T T.81 A.2.2).
/// That grid can be smaller than the MCU-padded plane.
///
/// # Errors
/// Invalid for SOS before SOF, bad `Ss/Se/Ah/Al`, interleaved AC scans,
/// unknown components, a missing mid-scan restart marker, or errors from
/// the arithmetic decoder.
fn decode_progressive_arith_scan(
    state: &JpegState,
    sos: &SosInfo,
    scan: &[u8],
    coefs: &mut [Vec<[i32; 64]>],
) -> Result<()> {
    let sof = state
        .sof
        .as_ref()
        .ok_or_else(|| Error::invalid("SOS before SOF"))?;
    if sos.ss > 63 || sos.se > 63 || sos.ss > sos.se {
        return Err(Error::invalid("progressive arith: invalid Ss/Se"));
    }
    let is_dc_scan = sos.ss == 0;
    if is_dc_scan && sos.se != 0 {
        return Err(Error::invalid("progressive arith: DC scan must have Se=0"));
    }
    if !is_dc_scan && sos.components.len() != 1 {
        return Err(Error::invalid(
            "progressive arith: AC scans must be non-interleaved",
        ));
    }
    if sos.ah > 13 || sos.al > 13 {
        return Err(Error::invalid("progressive arith: Ah/Al out of range"));
    }
    let h_max = sof.components.iter().map(|c| c.h_factor).max().unwrap_or(1) as usize;
    let v_max = sof.components.iter().map(|c| c.v_factor).max().unwrap_or(1) as usize;
    let width = sof.width as usize;
    let height = sof.height as usize;
    let mcus_x = width.div_ceil(8 * h_max);
    let mcus_y = height.div_ceil(8 * v_max);
    let sos_map: Vec<usize> = sos
        .components
        .iter()
        .map(|sc| {
            sof.components
                .iter()
                .position(|fc| fc.id == sc.id)
                .ok_or_else(|| Error::invalid("progressive arith SOS: unknown component"))
        })
        .collect::<Result<Vec<_>>>()?;
    // DC statistics per scan component, conditioned by its DC table's DAC entry.
    let mut dc_stats: Vec<DcStats> = sos
        .components
        .iter()
        .map(|sc| {
            let mut s = DcStats::new();
            if let Some(cond) = state.arith_dc[sc.dc_table as usize].as_ref() {
                s.l = cond.l;
                s.u = cond.u;
            }
            s
        })
        .collect();
    let mut ac_stats = AcStats::new();
    if let Some(sc) = sos.components.first() {
        if let Some(cond) = state.arith_ac[sc.ac_table as usize].as_ref() {
            ac_stats.kx = cond.kx;
        }
    }
    let mut ac_refine_stats = AcRefineStats::new();
    let mut scan_pos = 0usize;
    let mut decoder = ArithDecoder::new(scan);
    let mut mcus_since_restart: u32 = 0;
    let interleaved = is_dc_scan && sos.components.len() > 1;
    let (scan_mcus_x, scan_mcus_y) = if interleaved {
        (mcus_x, mcus_y)
    } else {
        // One block per MCU over the component's own block grid.
        let c = sof.components[sos_map[0]];
        (
            (width * c.h_factor as usize).div_ceil(8 * h_max),
            (height * c.v_factor as usize).div_ceil(8 * v_max),
        )
    };
    for my in 0..scan_mcus_y {
        for mx in 0..scan_mcus_x {
            if state.restart_interval != 0
                && mcus_since_restart != 0
                && mcus_since_restart % state.restart_interval as u32 == 0
            {
                scan_pos = locate_next_marker_after(scan, scan_pos);
                if scan_pos >= scan.len() {
                    return Err(Error::invalid(
                        "progressive arith: missing restart marker mid-scan",
                    ));
                }
                for s in dc_stats.iter_mut() {
                    s.restart_reset();
                }
                ac_stats.restart_reset();
                ac_refine_stats.restart_reset();
                decoder = ArithDecoder::new(&scan[scan_pos..]);
            }
            if is_dc_scan {
                if interleaved {
                    for (sidx, &sof_idx) in sos_map.iter().enumerate() {
                        let c = sof.components[sof_idx];
                        let blocks_x = mcus_x * c.h_factor as usize;
                        for by in 0..c.v_factor as usize {
                            for bx in 0..c.h_factor as usize {
                                let bidx_x = mx * c.h_factor as usize + bx;
                                let bidx_y = my * c.v_factor as usize + by;
                                let bi = bidx_y * blocks_x + bidx_x;
                                prog_arith_decode_dc(
                                    &mut decoder,
                                    &mut dc_stats[sidx],
                                    &mut coefs[sof_idx][bi],
                                    sos.ah,
                                    sos.al,
                                )?;
                            }
                        }
                    }
                } else {
                    let sof_idx = sos_map[0];
                    let c = sof.components[sof_idx];
                    // Row stride of the MCU-padded plane, not of the scan grid.
                    let blocks_x = mcus_x * c.h_factor as usize;
                    let bi = my * blocks_x + mx;
                    prog_arith_decode_dc(
                        &mut decoder,
                        &mut dc_stats[0],
                        &mut coefs[sof_idx][bi],
                        sos.ah,
                        sos.al,
                    )?;
                }
            } else {
                let sof_idx = sos_map[0];
                let c = sof.components[sof_idx];
                let blocks_x = mcus_x * c.h_factor as usize;
                let bi = my * blocks_x + mx;
                let ss = sos.ss as usize;
                let se = sos.se as usize;
                if sos.ah == 0 {
                    // First AC pass: decode into a scratch block, then store
                    // only the band's non-zero coefficients, scaled by 2^Al.
                    let mut tmp = [0i32; 64];
                    arith_decode_ac(&mut decoder, &mut ac_stats, &mut tmp, ss, se)?;
                    let block = &mut coefs[sof_idx][bi];
                    for k in ss..=se {
                        let pos = ZIGZAG[k];
                        if tmp[pos] != 0 {
                            block[pos] = tmp[pos] << sos.al;
                        }
                    }
                } else {
                    arith_decode_ac_refine(
                        &mut decoder,
                        &mut ac_refine_stats,
                        &mut coefs[sof_idx][bi],
                        ss,
                        se,
                        sos.al,
                    )?;
                }
            }
            mcus_since_restart += 1;
        }
    }
    Ok(())
}
/// Decodes one DC coefficient of an arithmetic-coded progressive scan.
///
/// First pass (`ah == 0`): the DC difference is decoded with the statistics
/// in `dc`, folded into the running predictor `dc.pred`, and the predictor
/// scaled up by `al` is written to `block[0]`.
/// Refinement pass (`ah != 0`): one fixed-probability bit is decoded and
/// OR-ed into bit `al` of `block[0]`.
fn prog_arith_decode_dc(
    d: &mut ArithDecoder<'_>,
    dc: &mut DcStats,
    block: &mut [i32; 64],
    ah: u8,
    al: u8,
) -> Result<()> {
    if ah != 0 {
        // Successive-approximation refinement: a single raw bit per block.
        let refinement = arith_decode_fixed_bit(d) as i32;
        block[0] |= refinement << al;
        return Ok(());
    }
    let delta = arith_decode_dc_diff(d, dc)?;
    dc.pred = dc.pred.wrapping_add(delta);
    block[0] = dc.pred << al;
    Ok(())
}
fn decode_progressive_scan(
state: &JpegState,
sos: &SosInfo,
scan: &[u8],
coefs: &mut [Vec<[i32; 64]>],
) -> Result<()> {
let sof = state
.sof
.as_ref()
.ok_or_else(|| Error::invalid("SOS before SOF"))?;
if sos.ss > 63 || sos.se > 63 || sos.ss > sos.se {
return Err(Error::invalid("progressive: invalid Ss/Se"));
}
let is_dc_scan = sos.ss == 0;
if is_dc_scan && sos.se != 0 {
return Err(Error::invalid("progressive: DC scan must have Se=0"));
}
if !is_dc_scan && sos.components.len() != 1 {
return Err(Error::invalid(
"progressive: AC scans must be non-interleaved",
));
}
if sos.ah > 13 || sos.al > 13 {
return Err(Error::invalid("progressive: Ah/Al out of range"));
}
let h_max = sof.components.iter().map(|c| c.h_factor).max().unwrap_or(1) as usize;
let v_max = sof.components.iter().map(|c| c.v_factor).max().unwrap_or(1) as usize;
let width = sof.width as usize;
let height = sof.height as usize;
let mcus_x = width.div_ceil(8 * h_max);
let mcus_y = height.div_ceil(8 * v_max);
let sos_map: Vec<usize> = sos
.components
.iter()
.map(|sc| {
sof.components
.iter()
.position(|fc| fc.id == sc.id)
.ok_or_else(|| Error::invalid("progressive SOS: unknown component"))
})
.collect::<Result<Vec<_>>>()?;
let dc_tables: Vec<Option<&HuffTable>> = sos
.components
.iter()
.map(|sc| {
if is_dc_scan {
state.dc_huff[sc.dc_table as usize].as_ref()
} else {
None
}
})
.collect();
let ac_tables: Vec<Option<&HuffTable>> = sos
.components
.iter()
.map(|sc| {
if is_dc_scan {
None
} else {
state.ac_huff[sc.ac_table as usize].as_ref()
}
})
.collect();
if is_dc_scan && dc_tables.iter().any(|t| t.is_none()) {
return Err(Error::invalid("progressive DC: Huffman table missing"));
}
if !is_dc_scan && ac_tables[0].is_none() {
return Err(Error::invalid("progressive AC: Huffman table missing"));
}
let mut br = BitReader::new(scan);
let mut prev_dc = vec![0i32; sos.components.len()];
let mut eob_run: u32 = 0;
let mut mcus_since_restart: u32 = 0;
let mut expected_rst: u8 = RST0;
let (scan_mcus_x, scan_mcus_y) = if is_dc_scan && sos.components.len() > 1 {
(mcus_x, mcus_y)
} else {
let ci = sos_map[0];
let c = sof.components[ci];
if sos.components.len() == 1 && sof.components.len() == 1 {
(mcus_x * c.h_factor as usize, mcus_y * c.v_factor as usize)
} else {
(mcus_x * c.h_factor as usize, mcus_y * c.v_factor as usize)
}
};
for my in 0..scan_mcus_y {
for mx in 0..scan_mcus_x {
if state.restart_interval != 0
&& mcus_since_restart != 0
&& mcus_since_restart % state.restart_interval as u32 == 0
{
br.bits = 0;
br.nbits = 0;
while br.saw_rst.is_none() {
let prev = br.pos;
match br.next_byte_with_stuff()? {
Some(_) => {}
None => {
if prev == br.pos {
break;
}
break;
}
}
}
if br.saw_rst.is_some() {
expected_rst = if expected_rst == RST7 {
RST0
} else {
expected_rst + 1
};
for p in prev_dc.iter_mut() {
*p = 0;
}
eob_run = 0;
br.reset_at_restart();
}
}
if is_dc_scan {
if sos.components.len() > 1 {
for (sidx, &sof_idx) in sos_map.iter().enumerate() {
let c = sof.components[sof_idx];
for by in 0..c.v_factor as usize {
for bx in 0..c.h_factor as usize {
let blocks_x = mcus_x * c.h_factor as usize;
let bidx_x = mx * c.h_factor as usize + bx;
let bidx_y = my * c.v_factor as usize + by;
let bi = bidx_y * blocks_x + bidx_x;
prog_decode_dc(
&mut br,
dc_tables[sidx].unwrap(),
&mut prev_dc[sidx],
&mut coefs[sof_idx][bi],
sos.ah,
sos.al,
)?;
}
}
}
} else {
let sof_idx = sos_map[0];
let c = sof.components[sof_idx];
let blocks_x = mcus_x * c.h_factor as usize;
let bi = my * blocks_x + mx;
prog_decode_dc(
&mut br,
dc_tables[0].unwrap(),
&mut prev_dc[0],
&mut coefs[sof_idx][bi],
sos.ah,
sos.al,
)?;
}
} else {
let sof_idx = sos_map[0];
let c = sof.components[sof_idx];
let blocks_x = mcus_x * c.h_factor as usize;
let bi = my * blocks_x + mx;
let ac = ac_tables[0].unwrap();
if sos.ah == 0 {
prog_decode_ac_first(
&mut br,
ac,
&mut coefs[sof_idx][bi],
sos.ss as usize,
sos.se as usize,
sos.al,
&mut eob_run,
)?;
} else {
prog_decode_ac_refine(
&mut br,
ac,
&mut coefs[sof_idx][bi],
sos.ss as usize,
sos.se as usize,
sos.al,
&mut eob_run,
)?;
}
}
mcus_since_restart += 1;
}
}
Ok(())
}
/// Decodes one DC coefficient of a Huffman-coded progressive scan.
///
/// First pass (`ah == 0`):
/// 1. A magnitude category is decoded from `dc`.
/// 2. That many raw bits are read and sign-extended into a difference.
/// 3. The difference is added to the running predictor `prev_dc`.
/// 4. The predictor shifted left by `al` is stored in `block[0]`.
///
/// Refinement pass (`ah != 0`): one raw bit is read and OR-ed into bit `al`
/// of `block[0]`.
fn prog_decode_dc(
    br: &mut BitReader<'_>,
    dc: &HuffTable,
    prev_dc: &mut i32,
    block: &mut [i32; 64],
    ah: u8,
    al: u8,
) -> Result<()> {
    if ah != 0 {
        let refinement = br.get_bits(1)? as i32;
        block[0] |= refinement << al;
        return Ok(());
    }
    let category = decode_huff(br, dc)? as u32;
    let delta = match category {
        0 => 0,
        n => extend(br.get_bits(n)? as i32, n),
    };
    *prev_dc = prev_dc.wrapping_add(delta);
    block[0] = *prev_dc << al;
    Ok(())
}
/// Decodes the first successive-approximation pass of the AC band `ss..=se`
/// for one block of a Huffman-coded progressive scan.
///
/// If an end-of-band run is pending (`*eob_run != 0`), the block is skipped
/// and the run decremented. Otherwise symbols are read until the band is
/// filled or an EOBn symbol starts a new run. Decoded coefficients are stored
/// at their natural-order position, shifted left by `al`.
///
/// # Errors
/// Returns an error if a zero run pushes a coefficient past `se`, or if the
/// Huffman decoder / bit reader fails.
fn prog_decode_ac_first(
    br: &mut BitReader<'_>,
    ac: &HuffTable,
    block: &mut [i32; 64],
    ss: usize,
    se: usize,
    al: u8,
    eob_run: &mut u32,
) -> Result<()> {
    if *eob_run != 0 {
        *eob_run -= 1;
        return Ok(());
    }
    let mut idx = ss;
    while idx <= se {
        let symbol = decode_huff(br, ac)?;
        let run = (symbol >> 4) as usize;
        let size = (symbol & 0x0F) as u32;
        match (run, size) {
            (15, 0) => {
                // ZRL: sixteen zero coefficients.
                idx += 16;
            }
            (_, 0) => {
                // EOBn: 2^run + extra blocks (this one included) end here;
                // the remainder is left for the following blocks.
                let extra = if run == 0 { 0 } else { br.get_bits(run as u32)? };
                *eob_run = (1u32 << run) + extra - 1;
                return Ok(());
            }
            _ => {
                idx += run;
                if idx > se {
                    return Err(Error::invalid("progressive AC: run out of band"));
                }
                let raw = br.get_bits(size)? as i32;
                block[ZIGZAG[idx]] = extend(raw, size) << al;
                idx += 1;
            }
        }
    }
    Ok(())
}
/// Decodes one successive-approximation refinement pass (`Ah != 0`) of the AC
/// band `ss..=se` for one block of a Huffman-coded progressive scan.
///
/// Each symbol's high nibble is a zero-run length `r`, and its low nibble
/// (size) must be 0 or 1:
/// - Size 1 introduces a newly significant coefficient of magnitude `1 << al`,
///   taking its sign from the next bit (0 = negative).
/// - Size 0 with run 15 (ZRL) skips zero-history coefficients without placing
///   a new one.
/// - Any other size-0 symbol starts an end-of-band run, stored in `eob_run`.
///
/// While walking the band, every already-non-zero coefficient consumes one
/// correction bit. If the bit is set and bit `al` is not yet set, `1 << al`
/// is added to the coefficient's magnitude.
///
/// # Errors
/// Returns an error in any of these cases:
/// - a symbol size other than 0 or 1;
/// - a zero run that walks past `se`;
/// - a failure of the Huffman decoder or bit reader.
fn prog_decode_ac_refine(
    br: &mut BitReader<'_>,
    ac: &HuffTable,
    block: &mut [i32; 64],
    ss: usize,
    se: usize,
    al: u8,
    eob_run: &mut u32,
) -> Result<()> {
    // +1 and -1 at the bit position being refined.
    let p1: i32 = 1 << al;
    let m1: i32 = -1 << al;
    let mut k = ss;
    // Not inside an end-of-band run: decode symbols for this block.
    if *eob_run == 0 {
        while k <= se {
            let rs = decode_huff(br, ac)?;
            let mut r = (rs >> 4) as usize;
            let s = (rs & 0x0F) as u32;
            let new_val: i32;
            if s == 0 {
                if r != 15 {
                    // EOBn: this block and the following (2^r + extra - 1)
                    // blocks get no new coefficients. Correction bits for the
                    // rest of this block are read by the tail loop below.
                    let extra = if r == 0 { 0 } else { br.get_bits(r as u32)? };
                    *eob_run = (1u32 << r) + extra;
                    break;
                }
                // ZRL: skip 16 zero-history coefficients, place nothing.
                new_val = 0;
            } else if s == 1 {
                let sign_bit = br.get_bits(1)?;
                new_val = if sign_bit == 0 { m1 } else { p1 };
            } else {
                return Err(Error::invalid("progressive AC refine: bad s"));
            }
            // Refine already-non-zero coefficients while skipping `r`
            // zero-history coefficients. On break, `k` indexes the zero
            // coefficient that receives `new_val`.
            loop {
                if k > se {
                    return Err(Error::invalid("progressive AC refine: k past se"));
                }
                let pos = ZIGZAG[k];
                if block[pos] != 0 {
                    let bit = br.get_bits(1)? as i32;
                    // Apply the correction bit once, moving away from zero.
                    if bit != 0 && (block[pos] & p1) == 0 {
                        if block[pos] >= 0 {
                            block[pos] += p1;
                        } else {
                            block[pos] += m1;
                        }
                    }
                } else if r == 0 {
                    break;
                } else {
                    r -= 1;
                }
                k += 1;
            }
            if new_val != 0 && k <= se {
                block[ZIGZAG[k]] = new_val;
            }
            k += 1;
        }
    }
    // Inside an end-of-band run (pending on entry or just started above).
    // Only correction bits for the remaining non-zero coefficients are read,
    // then one block of the run is consumed.
    if *eob_run > 0 {
        while k <= se {
            let pos = ZIGZAG[k];
            if block[pos] != 0 {
                let bit = br.get_bits(1)? as i32;
                if bit != 0 && (block[pos] & p1) == 0 {
                    if block[pos] >= 0 {
                        block[pos] += p1;
                    } else {
                        block[pos] += m1;
                    }
                }
            }
            k += 1;
        }
        *eob_run -= 1;
    }
    Ok(())
}
fn render_from_coefs(
state: &JpegState,
coefs: &[Vec<[i32; 64]>],
pts: Option<i64>,
) -> Result<VideoFrame> {
let sof = state
.sof
.as_ref()
.ok_or_else(|| Error::invalid("render: EOI before SOF"))?;
if sof.precision == 12 {
return render_from_coefs_12bit(state, coefs, pts);
}
let n_comp = sof.components.len();
let grayscale = n_comp == 1;
let width = sof.width as usize;
let height = sof.height as usize;
let h_max = sof.components.iter().map(|c| c.h_factor).max().unwrap_or(1) as usize;
let v_max = sof.components.iter().map(|c| c.v_factor).max().unwrap_or(1) as usize;
let mcus_x = width.div_ceil(8 * h_max);
let mcus_y = height.div_ceil(8 * v_max);
let is_rgb = detect_rgb_3comp(sof, state.adobe_transform);
let out_format = if grayscale {
PixelFormat::Gray8
} else if is_rgb {
for c in &sof.components {
if c.h_factor != 1 || c.v_factor != 1 {
return Err(Error::unsupported(
"RGB JPEG: every component must declare H = V = 1",
));
}
}
PixelFormat::Rgb24
} else if n_comp == 3 {
let y = sof.components[0];
let cb = sof.components[1];
let cr = sof.components[2];
if cb.h_factor != cr.h_factor || cb.v_factor != cr.v_factor {
return Err(Error::unsupported(
"chroma components have different sampling factors",
));
}
if cb.h_factor != 1 || cb.v_factor != 1 {
return Err(Error::unsupported("chroma components must have factor 1"));
}
match (y.h_factor, y.v_factor) {
(1, 1) => PixelFormat::Yuv444P,
(2, 1) => PixelFormat::Yuv422P,
(2, 2) => PixelFormat::Yuv420P,
(4, 1) => PixelFormat::Yuv411P,
_ => {
return Err(Error::unsupported(format!(
"luma sampling {}x{}",
y.h_factor, y.v_factor
)))
}
}
} else if n_comp == 4 {
PixelFormat::Cmyk
} else {
return Err(Error::unsupported("2-component JPEG"));
};
let quant_tables: Vec<&QuantTable> = sof
.components
.iter()
.map(|c| {
state.quant[c.qt_id as usize]
.as_ref()
.ok_or_else(|| Error::invalid("quant table missing for component"))
})
.collect::<Result<Vec<_>>>()?;
let mut comp_buf: Vec<Vec<u8>> = Vec::with_capacity(n_comp);
let mut comp_stride: Vec<usize> = Vec::with_capacity(n_comp);
for c in &sof.components {
let w_full = mcus_x * 8 * c.h_factor as usize;
let h_full = mcus_y * 8 * c.v_factor as usize;
comp_buf.push(vec![0u8; w_full * h_full]);
comp_stride.push(w_full);
}
for (ci, c) in sof.components.iter().enumerate() {
let blocks_x = mcus_x * c.h_factor as usize;
let blocks_y = mcus_y * c.v_factor as usize;
let qt = quant_tables[ci];
let stride = comp_stride[ci];
let buf = &mut comp_buf[ci];
for by in 0..blocks_y {
for bx in 0..blocks_x {
let block = &coefs[ci][by * blocks_x + bx];
let mut fblock = [0.0f32; 64];
for k in 0..64 {
fblock[k] = block[k] as f32 * qt.values[k] as f32;
}
idct8x8(&mut fblock);
let dst_x0 = bx * 8;
let dst_y0 = by * 8;
for j in 0..8 {
for i in 0..8 {
let v = fblock[j * 8 + i] + 128.0;
let px = if v <= 0.0 {
0
} else if v >= 255.0 {
255
} else {
v.round() as u8
};
buf[(dst_y0 + j) * stride + dst_x0 + i] = px;
}
}
}
}
}
let mut planes: Vec<VideoPlane> = Vec::new();
match out_format {
PixelFormat::Gray8 => {
let stride = width;
let mut data = vec![0u8; stride * height];
let src_stride = comp_stride[0];
for y in 0..height {
data[y * stride..y * stride + width]
.copy_from_slice(&comp_buf[0][y * src_stride..y * src_stride + width]);
}
planes.push(VideoPlane { stride, data });
}
PixelFormat::Rgb24 => {
let stride = width * 3;
let mut data = vec![0u8; stride * height];
let src_strides = [comp_stride[0], comp_stride[1], comp_stride[2]];
for y in 0..height {
let off = y * stride;
for x in 0..width {
data[off + x * 3] = comp_buf[0][y * src_strides[0] + x];
data[off + x * 3 + 1] = comp_buf[1][y * src_strides[1] + x];
data[off + x * 3 + 2] = comp_buf[2][y * src_strides[2] + x];
}
}
planes.push(VideoPlane { stride, data });
}
PixelFormat::Yuv444P
| PixelFormat::Yuv422P
| PixelFormat::Yuv420P
| PixelFormat::Yuv411P => {
let (c_w, c_h) = match out_format {
PixelFormat::Yuv444P => (width, height),
PixelFormat::Yuv422P => (width.div_ceil(2), height),
PixelFormat::Yuv420P => (width.div_ceil(2), height.div_ceil(2)),
PixelFormat::Yuv411P => (width.div_ceil(4), height),
_ => unreachable!(),
};
let y_stride = width;
let mut y_data = vec![0u8; y_stride * height];
let src_stride_y = comp_stride[0];
for y in 0..height {
y_data[y * y_stride..y * y_stride + width]
.copy_from_slice(&comp_buf[0][y * src_stride_y..y * src_stride_y + width]);
}
planes.push(VideoPlane {
stride: y_stride,
data: y_data,
});
for ci in [1usize, 2] {
let src_stride = comp_stride[ci];
let stride = c_w;
let mut data = vec![0u8; stride * c_h];
for y in 0..c_h {
data[y * stride..y * stride + c_w]
.copy_from_slice(&comp_buf[ci][y * src_stride..y * src_stride + c_w]);
}
planes.push(VideoPlane { stride, data });
}
}
PixelFormat::Cmyk => {
let stride = width * 4;
let mut data = vec![0u8; stride * height];
let fh: [usize; 4] = [
sof.components[0].h_factor as usize,
sof.components[1].h_factor as usize,
sof.components[2].h_factor as usize,
sof.components[3].h_factor as usize,
];
let fv: [usize; 4] = [
sof.components[0].v_factor as usize,
sof.components[1].v_factor as usize,
sof.components[2].v_factor as usize,
sof.components[3].v_factor as usize,
];
let transform = state.adobe_transform;
for y in 0..height {
for x in 0..width {
let mut s = [0u8; 4];
for ci in 0..4 {
let sx = x * fh[ci] / h_max;
let sy = y * fv[ci] / v_max;
s[ci] = comp_buf[ci][sy * comp_stride[ci] + sx];
}
let (c, m, yy, k) = match transform {
Some(2) => {
let y_s = s[0] as i32;
let cb = s[1] as i32 - 128;
let cr = s[2] as i32 - 128;
let r = (y_s + ((cr * 91881 + 32768) >> 16)).clamp(0, 255);
let g =
(y_s + ((-22554 * cb - 46802 * cr + 32768) >> 16)).clamp(0, 255);
let b = (y_s + ((cb * 116130 + 32768) >> 16)).clamp(0, 255);
(
(255 - r) as u8,
(255 - g) as u8,
(255 - b) as u8,
255 - s[3],
)
}
Some(0) => {
(255 - s[0], 255 - s[1], 255 - s[2], 255 - s[3])
}
_ => {
(s[0], s[1], s[2], s[3])
}
};
let o = y * stride + x * 4;
data[o] = c;
data[o + 1] = m;
data[o + 2] = yy;
data[o + 3] = k;
}
}
planes.push(VideoPlane { stride, data });
}
_ => unreachable!(),
}
Ok(VideoFrame { pts, planes })
}
fn render_from_coefs_12bit(
state: &JpegState,
coefs: &[Vec<[i32; 64]>],
pts: Option<i64>,
) -> Result<VideoFrame> {
let sof = state
.sof
.as_ref()
.ok_or_else(|| Error::invalid("render: EOI before SOF"))?;
let n_comp = sof.components.len();
let grayscale = n_comp == 1;
let width = sof.width as usize;
let height = sof.height as usize;
let h_max = sof.components.iter().map(|c| c.h_factor).max().unwrap_or(1) as usize;
let v_max = sof.components.iter().map(|c| c.v_factor).max().unwrap_or(1) as usize;
let mcus_x = width.div_ceil(8 * h_max);
let mcus_y = height.div_ceil(8 * v_max);
let out_format = if grayscale {
PixelFormat::Gray12Le
} else if n_comp == 3 {
let y = sof.components[0];
let cb = sof.components[1];
let cr = sof.components[2];
if cb.h_factor != cr.h_factor || cb.v_factor != cr.v_factor {
return Err(Error::unsupported(
"12-bit: chroma components have different sampling factors",
));
}
if cb.h_factor != 1 || cb.v_factor != 1 {
return Err(Error::unsupported(
"12-bit: chroma components must have factor 1",
));
}
match (y.h_factor, y.v_factor) {
(1, 1) => PixelFormat::Yuv444P12Le,
(2, 1) => PixelFormat::Yuv422P12Le,
(2, 2) => PixelFormat::Yuv420P12Le,
_ => {
return Err(Error::unsupported(format!(
"12-bit: only 4:4:4 / 4:2:2 / 4:2:0 chroma sampling supported (got {}x{})",
y.h_factor, y.v_factor
)))
}
}
} else {
return Err(Error::unsupported(format!(
"12-bit: {n_comp}-component JPEGs not supported"
)));
};
let quant_tables: Vec<&QuantTable> = sof
.components
.iter()
.map(|c| {
state.quant[c.qt_id as usize]
.as_ref()
.ok_or_else(|| Error::invalid("quant table missing for component"))
})
.collect::<Result<Vec<_>>>()?;
let mut comp_buf: Vec<Vec<u16>> = Vec::with_capacity(n_comp);
let mut comp_stride: Vec<usize> = Vec::with_capacity(n_comp);
for c in &sof.components {
let w_full = mcus_x * 8 * c.h_factor as usize;
let h_full = mcus_y * 8 * c.v_factor as usize;
comp_buf.push(vec![0u16; w_full * h_full]);
comp_stride.push(w_full);
}
for (ci, c) in sof.components.iter().enumerate() {
let blocks_x = mcus_x * c.h_factor as usize;
let blocks_y = mcus_y * c.v_factor as usize;
let qt = quant_tables[ci];
let stride = comp_stride[ci];
let buf = &mut comp_buf[ci];
for by in 0..blocks_y {
for bx in 0..blocks_x {
let block = &coefs[ci][by * blocks_x + bx];
let mut fblock = [0.0f32; 64];
for k in 0..64 {
fblock[k] = block[k] as f32 * qt.values[k] as f32;
}
idct8x8(&mut fblock);
let dst_x0 = bx * 8;
let dst_y0 = by * 8;
for j in 0..8 {
for i in 0..8 {
let v = fblock[j * 8 + i] + 2048.0;
let px = if v <= 0.0 {
0
} else if v >= 4095.0 {
4095
} else {
v.round() as u16
};
buf[(dst_y0 + j) * stride + dst_x0 + i] = px;
}
}
}
}
}
let mut planes: Vec<VideoPlane> = Vec::new();
let emit_plane = |src: &[u16], src_stride: usize, w: usize, h: usize| -> VideoPlane {
let stride = w * 2;
let mut data = vec![0u8; stride * h];
for y in 0..h {
for x in 0..w {
let v = src[y * src_stride + x];
data[y * stride + x * 2] = (v & 0xFF) as u8;
data[y * stride + x * 2 + 1] = ((v >> 8) & 0xFF) as u8;
}
}
VideoPlane { stride, data }
};
match out_format {
PixelFormat::Gray12Le => {
planes.push(emit_plane(&comp_buf[0], comp_stride[0], width, height));
}
PixelFormat::Yuv420P12Le | PixelFormat::Yuv422P12Le | PixelFormat::Yuv444P12Le => {
let (c_w, c_h) = match out_format {
PixelFormat::Yuv444P12Le => (width, height),
PixelFormat::Yuv422P12Le => (width.div_ceil(2), height),
PixelFormat::Yuv420P12Le => (width.div_ceil(2), height.div_ceil(2)),
_ => unreachable!(),
};
planes.push(emit_plane(&comp_buf[0], comp_stride[0], width, height));
planes.push(emit_plane(&comp_buf[1], comp_stride[1], c_w, c_h));
planes.push(emit_plane(&comp_buf[2], comp_stride[2], c_w, c_h));
}
_ => unreachable!(),
}
Ok(VideoFrame { pts, planes })
}
/// Decodes a Huffman-coded lossless scan into an output frame.
///
/// Only scans that interleave every frame component are supported: 1, 3 or
/// 4 components, with 4 components only at 8-bit precision.
///
/// Each sample is reconstructed as `(prediction + difference) mod 2^(P - Pt)`.
/// The prediction follows ITU-T T.81 H.1.2.1:
/// - `2^(P - Pt - 1)` for the first sample of the scan and of every restart
///   interval;
/// - the left neighbour (Ra) on the first line of the scan and of every
///   restart interval;
/// - the upper neighbour (Rb) at the start of every other line;
/// - the predictor selected by `Ss` everywhere else.
///
/// # Errors
/// Returns an error in any of these cases:
/// - no SOF has been seen;
/// - the component layout is not supported;
/// - the predictor or point transform is invalid;
/// - a DC Huffman table is missing;
/// - the entropy-coded data is malformed.
fn decode_lossless_scan(
    state: &JpegState,
    sos: &SosInfo,
    scan: &[u8],
    pts: Option<i64>,
) -> Result<VideoFrame> {
    let sof = state
        .sof
        .as_ref()
        .ok_or_else(|| Error::invalid("SOS before SOF"))?;
    if !matches!(sos.components.len(), 1 | 3 | 4) {
        return Err(Error::unsupported(format!(
            "lossless: {} component(s) — only 1, 3 and 4 are supported",
            sos.components.len()
        )));
    }
    if sos.components.len() != sof.components.len() {
        return Err(Error::unsupported(
            "lossless: non-interleaved multi-component scans are not supported",
        ));
    }
    if sos.components.len() == 4 && sof.precision != 8 {
        return Err(Error::unsupported(
            "lossless: 4-component scans require precision 8",
        ));
    }
    let predictor = sos.ss;
    if !(1..=7).contains(&predictor) {
        return Err(Error::invalid("lossless: predictor Ss must be in 1..=7"));
    }
    // Point transform (Pt) and sample precision (P).
    let pt = sos.al as u32;
    let precision = sof.precision as u32;
    if pt >= precision {
        return Err(Error::invalid("lossless: Pt >= precision"));
    }
    let width = sof.width as usize;
    let height = sof.height as usize;
    let nc = sos.components.len();
    let mut dc_tables: Vec<&HuffTable> = Vec::with_capacity(nc);
    for sc in &sos.components {
        let t = state.dc_huff[sc.dc_table as usize]
            .as_ref()
            .ok_or_else(|| Error::invalid("lossless: DC Huffman table missing"))?;
        dc_tables.push(t);
    }
    let sample_bits = precision - pt;
    let sample_max: u32 = 1u32 << sample_bits;
    let sample_mask: u32 = sample_max - 1;
    let origin: u32 = 1u32 << (sample_bits - 1);
    let mut samples: Vec<Vec<u32>> = (0..nc).map(|_| vec![0u32; width * height]).collect();
    let mut br = BitReader::new(scan);
    let mut mcus_since_restart: u32 = 0;
    let mut reset_pred = true;
    // Row on which the scan or the current restart interval began. That row
    // is predicted horizontally (Ra), like the first row of the scan.
    let mut first_line_y = 0usize;
    for y in 0..height {
        for x in 0..width {
            if state.restart_interval != 0
                && mcus_since_restart != 0
                && mcus_since_restart % state.restart_interval as u32 == 0
            {
                // Drop buffered bits and advance until the bit reader reports
                // an RSTn marker or the scan data runs out.
                br.bits = 0;
                br.nbits = 0;
                while br.saw_rst.is_none() {
                    if br.next_byte_with_stuff()?.is_none() {
                        break;
                    }
                }
                if br.saw_rst.is_some() {
                    br.reset_at_restart();
                    reset_pred = true;
                    first_line_y = y;
                }
            }
            for ci in 0..nc {
                let plane = &samples[ci];
                let pred: u32 = if reset_pred {
                    origin
                } else if y == first_line_y {
                    plane[y * width + x - 1]
                } else if x == 0 {
                    plane[(y - 1) * width + x]
                } else {
                    let ra = plane[y * width + x - 1];
                    let rb = plane[(y - 1) * width + x];
                    let rc = plane[(y - 1) * width + x - 1];
                    // Wrapping u32 arithmetic: any difference from the signed
                    // formulas is removed by `sample_mask` below.
                    match predictor {
                        1 => ra,
                        2 => rb,
                        3 => rc,
                        4 => ra.wrapping_add(rb).wrapping_sub(rc),
                        5 => ra.wrapping_add(rb.wrapping_sub(rc) >> 1),
                        6 => rb.wrapping_add(ra.wrapping_sub(rc) >> 1),
                        7 => (ra.wrapping_add(rb)) >> 1,
                        _ => unreachable!(),
                    }
                };
                let s = decode_huff(&mut br, dc_tables[ci])? as u32;
                if s > 16 {
                    return Err(Error::invalid("lossless: SSSS > 16"));
                }
                // SSSS = 16 encodes a difference of 32768 with no extra bits.
                let residual: i32 = if s == 0 {
                    0
                } else if s == 16 {
                    32_768
                } else {
                    let bits = br.get_bits(s)? as i32;
                    extend(bits, s)
                };
                let sv = ((pred as i32).wrapping_add(residual) as u32) & sample_mask;
                samples[ci][y * width + x] = sv;
            }
            reset_pred = false;
            mcus_since_restart += 1;
        }
    }
    shape_lossless_frame(&samples, nc, width, height, pt, precision, state, pts)
}
/// Packs reconstructed lossless samples into an output frame.
///
/// Each sample is first scaled back up by the point transform `pt`.
///
/// 4 components: one packed plane with 4 bytes per pixel. Adobe transform 2
/// (YCCK) is colour-converted, transform 0 is inverted, and anything else
/// passes through unchanged.
///
/// 3 components:
/// - precision 8: packed 8-bit RGB;
/// - precision 10/12/14: three little-endian 16-bit planes;
/// - any other precision: one packed plane of three little-endian 16-bit
///   samples per pixel.
///
/// Any other component count: only the first component is emitted, as 8-bit
/// samples at precision 8 and little-endian 16-bit samples otherwise.
#[allow(clippy::too_many_arguments)]
fn shape_lossless_frame(
    samples: &[Vec<u32>],
    nc: usize,
    width: usize,
    height: usize,
    pt: u32,
    precision: u32,
    state: &JpegState,
    pts: Option<i64>,
) -> Result<VideoFrame> {
    let count = width * height;
    // Sample `i` of component `plane`, with the point transform undone.
    let at = |plane: usize, i: usize| samples[plane][i] << pt;
    let le16 = |v: u32| (v as u16).to_le_bytes();
    let packed = |stride: usize, data: Vec<u8>| VideoFrame {
        pts,
        planes: vec![VideoPlane { stride, data }],
    };
    match nc {
        4 => {
            let mut data = Vec::with_capacity(count * 4);
            for i in 0..count {
                let s = [at(0, i) as u8, at(1, i) as u8, at(2, i) as u8, at(3, i) as u8];
                let cmyk = match state.adobe_transform {
                    Some(2) => {
                        let luma = s[0] as i32;
                        let cb = s[1] as i32 - 128;
                        let cr = s[2] as i32 - 128;
                        let r = (luma + ((cr * 91881 + 32768) >> 16)).clamp(0, 255);
                        let g = (luma + ((-22554 * cb - 46802 * cr + 32768) >> 16)).clamp(0, 255);
                        let b = (luma + ((cb * 116130 + 32768) >> 16)).clamp(0, 255);
                        [(255 - r) as u8, (255 - g) as u8, (255 - b) as u8, 255 - s[3]]
                    }
                    Some(0) => s.map(|v| 255 - v),
                    _ => s,
                };
                data.extend_from_slice(&cmyk);
            }
            Ok(packed(width * 4, data))
        }
        3 if precision == 8 => {
            let mut data = Vec::with_capacity(count * 3);
            for i in 0..count {
                data.extend_from_slice(&[at(0, i) as u8, at(1, i) as u8, at(2, i) as u8]);
            }
            Ok(packed(width * 3, data))
        }
        3 if matches!(precision, 10 | 12 | 14) => {
            let planes = (0..3)
                .map(|p| {
                    let mut data = Vec::with_capacity(count * 2);
                    for i in 0..count {
                        data.extend_from_slice(&le16(at(p, i)));
                    }
                    VideoPlane {
                        stride: width * 2,
                        data,
                    }
                })
                .collect();
            Ok(VideoFrame { pts, planes })
        }
        3 => {
            let mut data = Vec::with_capacity(count * 6);
            for i in 0..count {
                for p in 0..3 {
                    data.extend_from_slice(&le16(at(p, i)));
                }
            }
            Ok(packed(width * 6, data))
        }
        _ if precision == 8 => {
            let data = (0..count).map(|i| at(0, i) as u8).collect();
            Ok(packed(width, data))
        }
        _ => {
            let mut data = Vec::with_capacity(count * 2);
            for i in 0..count {
                data.extend_from_slice(&le16(at(0, i)));
            }
            Ok(packed(width * 2, data))
        }
    }
}
fn decode_lossless_arith_scan(
state: &JpegState,
sos: &SosInfo,
scan: &[u8],
pts: Option<i64>,
) -> Result<VideoFrame> {
let sof = state
.sof
.as_ref()
.ok_or_else(|| Error::invalid("SOS before SOF"))?;
if sos.components.len() != sof.components.len() {
return Err(Error::unsupported(
"lossless arith: non-interleaved multi-component scans are not supported",
));
}
let predictor = sos.ss;
if !(1..=7).contains(&predictor) {
return Err(Error::invalid(
"lossless arith: predictor Ss must be in 1..=7",
));
}
let pt = sos.al as u32; let precision = sof.precision as u32;
if pt >= precision {
return Err(Error::invalid("lossless arith: Pt >= precision"));
}
let width = sof.width as usize;
let height = sof.height as usize;
let nc = sos.components.len();
let mut stats: Vec<LosslessStats> = sos
.components
.iter()
.map(|sc| {
let mut s = LosslessStats::new();
if let Some(cond) = state.arith_dc[sc.dc_table as usize].as_ref() {
s.l = cond.l;
s.u = cond.u;
}
s
})
.collect();
let sample_bits = precision - pt;
let sample_max: u32 = 1u32 << sample_bits;
let sample_mask: u32 = sample_max - 1;
let origin: u32 = 1u32 << (sample_bits - 1);
let mut samples: Vec<Vec<u32>> = (0..nc).map(|_| vec![0u32; width * height]).collect();
let mut prev_diff: Vec<Vec<i32>> = (0..nc).map(|_| vec![0i32; width]).collect();
let mut cur_diff: Vec<Vec<i32>> = (0..nc).map(|_| vec![0i32; width]).collect();
let mut scan_pos = 0usize;
let mut decoder = ArithDecoder::new(scan);
let mut samples_since_restart: u32 = 0;
let mut reset_pred = true; let mut first_line_y = 0usize;
for y in 0..height {
for x in 0..width {
if state.restart_interval != 0
&& samples_since_restart != 0
&& samples_since_restart % state.restart_interval as u32 == 0
{
scan_pos = locate_next_marker_after(scan, scan_pos);
if scan_pos >= scan.len() {
return Err(Error::invalid(
"lossless arith: missing restart marker mid-scan",
));
}
for s in stats.iter_mut() {
s.reset();
}
for row in prev_diff.iter_mut() {
row.fill(0);
}
for row in cur_diff.iter_mut() {
row.fill(0);
}
decoder = ArithDecoder::new(&scan[scan_pos..]);
reset_pred = true;
first_line_y = y;
}
for ci in 0..nc {
let plane = &samples[ci];
let pred: u32 = if reset_pred {
origin
} else if y == first_line_y {
plane[y * width + x - 1]
} else if x == 0 {
plane[(y - 1) * width + x]
} else {
let ra = plane[y * width + x - 1];
let rb = plane[(y - 1) * width + x];
let rc = plane[(y - 1) * width + x - 1];
match predictor {
1 => ra,
2 => rb,
3 => rc,
4 => ra.wrapping_add(rb).wrapping_sub(rc),
5 => ra.wrapping_add(rb.wrapping_sub(rc) >> 1),
6 => rb.wrapping_add(ra.wrapping_sub(rc) >> 1),
7 => (ra.wrapping_add(rb)) >> 1,
_ => unreachable!(),
}
};
let da = if x == 0 { 0 } else { cur_diff[ci][x - 1] };
let db = prev_diff[ci][x];
let diff = arith_decode_lossless_diff(&mut decoder, &mut stats[ci], da, db)?;
cur_diff[ci][x] = diff;
let sv = ((pred as i32).wrapping_add(diff) as u32) & sample_mask;
samples[ci][y * width + x] = sv;
}
reset_pred = false;
samples_since_restart += 1;
}
std::mem::swap(&mut prev_diff, &mut cur_diff);
}
shape_lossless_frame(&samples, nc, width, height, pt, precision, state, pts)
}
#[cfg(all(test, feature = "registry"))]
mod prog_tests {
    use super::*;
    use crate::jpeg::huffman::{
        HuffTable, STD_AC_LUMA_BITS, STD_AC_LUMA_VALS, STD_DC_LUMA_BITS, STD_DC_LUMA_VALS,
    };
    use crate::jpeg::parser::SofComponent;

    /// The crate's standard DC luminance table (`STD_DC_LUMA_*`).
    fn std_dc_table() -> HuffTable {
        HuffTable::build(&STD_DC_LUMA_BITS, &STD_DC_LUMA_VALS).unwrap()
    }

    /// The crate's standard AC luminance table (`STD_AC_LUMA_*`).
    fn std_ac_table() -> HuffTable {
        HuffTable::build(&STD_AC_LUMA_BITS, &STD_AC_LUMA_VALS).unwrap()
    }

    /// Single-component 8-bit state with flat quantization and the standard
    /// DC luminance table installed in slot 0.
    fn make_state(progressive: bool, width: u16, height: u16) -> JpegState {
        let luma = SofComponent {
            id: 1,
            h_factor: 1,
            v_factor: 1,
            qt_id: 0,
        };
        let mut state = JpegState::new();
        state.sof = Some(SofInfo {
            precision: 8,
            height,
            width,
            components: vec![luma],
        });
        state.progressive = progressive;
        state.quant[0] = Some(QuantTable { values: [1u16; 64] });
        state.dc_huff[0] = Some(std_dc_table());
        state
    }

    #[test]
    fn accumulator_sizes() {
        let component = |id, h_factor, v_factor, qt_id| SofComponent {
            id,
            h_factor,
            v_factor,
            qt_id,
        };
        let sof = SofInfo {
            precision: 8,
            height: 16,
            width: 16,
            components: vec![
                component(1, 2, 2, 0),
                component(2, 1, 1, 1),
                component(3, 1, 1, 1),
            ],
        };
        let coefs = init_coef_buffers(&sof).unwrap();
        assert_eq!([coefs[0].len(), coefs[1].len(), coefs[2].len()], [4, 1, 1]);
        assert!(coefs[0].iter().flatten().all(|&v| v == 0));
    }

    #[test]
    fn dc_first_pass_shifts_by_al() {
        let table = std_dc_table();
        let mut writer = ProgTestBitWriter::new();
        // Category 3 followed by the 3-bit magnitude 0b101 = +5.
        writer.put_symbol(&table, 3);
        writer.put(0b101, 3);
        let bytes = writer.finish();
        let mut reader = BitReader::new(&bytes);
        let (mut pred, mut block) = (0i32, [0i32; 64]);
        prog_decode_dc(&mut reader, &table, &mut pred, &mut block, 0, 2).unwrap();
        assert_eq!(pred, 5);
        assert_eq!(block[0], 20);
    }

    #[test]
    fn dc_refine_appends_bit() {
        let mut writer = ProgTestBitWriter::new();
        writer.put(1, 1);
        let bytes = writer.finish();
        let table = std_dc_table();
        let mut reader = BitReader::new(&bytes);
        let mut pred = 0i32;
        let mut block = [0i32; 64];
        block[0] = 0b1000;
        prog_decode_dc(&mut reader, &table, &mut pred, &mut block, 1, 0).unwrap();
        assert_eq!(block[0], 0b1001);
    }

    #[test]
    fn state_helper_is_coherent() {
        let state = make_state(true, 8, 8);
        assert!(state.progressive);
        assert!(state.sof.is_some());
    }

    /// MSB-first bit writer with JPEG 0xFF byte stuffing, for building test
    /// bitstreams.
    struct ProgTestBitWriter {
        out: Vec<u8>,
        bits: u32,
        nbits: u32,
    }

    impl ProgTestBitWriter {
        fn new() -> Self {
            Self {
                out: Vec::new(),
                bits: 0,
                nbits: 0,
            }
        }

        /// Emits one byte, stuffing a 0x00 after every 0xFF.
        fn push_byte(&mut self, byte: u8) {
            self.out.push(byte);
            if byte == 0xFF {
                self.out.push(0x00);
            }
        }

        /// Appends the low `n` bits of `val`.
        fn put(&mut self, val: u32, n: u32) {
            let mask = (1u32 << n) - 1;
            self.bits = (self.bits << n) | (val & mask);
            self.nbits += n;
            while self.nbits >= 8 {
                self.nbits -= 8;
                self.push_byte((self.bits >> self.nbits) as u8);
            }
        }

        /// Appends the Huffman code for `symbol` from `table`.
        fn put_symbol(&mut self, table: &HuffTable, symbol: usize) {
            let entry = table.encode[symbol];
            self.put(entry.code as u32, entry.len as u32);
        }

        /// Pads the final partial byte with 1-bits and returns the stream.
        fn finish(mut self) -> Vec<u8> {
            if self.nbits != 0 {
                let pad = 8 - self.nbits;
                let last = ((self.bits << pad) | ((1u32 << pad) - 1)) as u8;
                self.push_byte(last);
            }
            self.out
        }
    }

    #[test]
    fn ac_first_single_coef_then_eob() {
        let ac = std_ac_table();
        let mut writer = ProgTestBitWriter::new();
        // Run 0 / size 1 with magnitude bit 1 (= +1), then EOB.
        writer.put_symbol(&ac, 0x01);
        writer.put(1, 1);
        writer.put_symbol(&ac, 0x00);
        let bytes = writer.finish();
        let mut reader = BitReader::new(&bytes);
        let mut block = [0i32; 64];
        let mut eob = 0u32;
        prog_decode_ac_first(&mut reader, &ac, &mut block, 1, 63, 0, &mut eob).unwrap();
        assert_eq!(block[ZIGZAG[1]], 1);
        for (i, &v) in block.iter().enumerate() {
            if i != ZIGZAG[1] {
                assert_eq!(v, 0, "unexpected nonzero at {i}");
            }
        }
    }

    #[test]
    fn ac_first_pass_shifts_by_al() {
        let ac = std_ac_table();
        let mut writer = ProgTestBitWriter::new();
        writer.put_symbol(&ac, 0x01);
        writer.put(1, 1);
        writer.put_symbol(&ac, 0x00);
        let bytes = writer.finish();
        let mut reader = BitReader::new(&bytes);
        let mut block = [0i32; 64];
        let mut eob = 0u32;
        prog_decode_ac_first(&mut reader, &ac, &mut block, 1, 63, 2, &mut eob).unwrap();
        assert_eq!(block[ZIGZAG[1]], 4);
    }

    #[test]
    fn ac_refine_extends_existing_coef() {
        let ac = std_ac_table();
        let mut writer = ProgTestBitWriter::new();
        // EOB, then one correction bit for the already-significant coefficient.
        writer.put_symbol(&ac, 0x00);
        writer.put(1, 1);
        let bytes = writer.finish();
        let mut reader = BitReader::new(&bytes);
        let mut block = [0i32; 64];
        block[ZIGZAG[1]] = 4;
        let mut eob = 0u32;
        prog_decode_ac_refine(&mut reader, &ac, &mut block, 1, 63, 1, &mut eob).unwrap();
        assert_eq!(block[ZIGZAG[1]], 6);
    }
}
#[cfg(all(test, feature = "registry"))]
mod non_interleaved_tests {
    use super::*;
    use crate::encoder::{encode_jpeg, encode_jpeg_non_interleaved};
    use crate::registry::make_decoder;
    use oxideav_core::{CodecId, CodecParameters, Frame, Packet, TimeBase};

    /// Builds a deterministic planar Y/Cb/Cr frame.
    ///
    /// Luma is a diagonal ramp. Cb varies with column and Cr varies with row.
    fn make_frame(w: u32, h: u32, pix: PixelFormat) -> VideoFrame {
        let (cw, ch) = match pix {
            PixelFormat::Yuv444P => (w, h),
            PixelFormat::Yuv422P => (w.div_ceil(2), h),
            PixelFormat::Yuv420P => (w.div_ceil(2), h.div_ceil(2)),
            _ => panic!("unsupported"),
        };
        let (w, h, cw, ch) = (w as usize, h as usize, cw as usize, ch as usize);
        let luma: Vec<u8> = (0..h)
            .flat_map(|j| (0..w).map(move |i| (((i + j * 3) * 7) % 255) as u8))
            .collect();
        let cb: Vec<u8> = (0..ch)
            .flat_map(|_| (0..cw).map(|i| (128 + i / 2) as u8))
            .collect();
        let cr: Vec<u8> = (0..ch)
            .flat_map(|j| std::iter::repeat((128 + j / 2) as u8).take(cw))
            .collect();
        VideoFrame {
            pts: Some(0),
            planes: vec![
                VideoPlane {
                    stride: w,
                    data: luma,
                },
                VideoPlane {
                    stride: cw,
                    data: cb,
                },
                VideoPlane {
                    stride: cw,
                    data: cr,
                },
            ],
        }
    }

    /// Encodes the same frame interleaved and non-interleaved, decodes both,
    /// and requires byte-identical planes.
    fn assert_matches_interleaved(w: u32, h: u32, pix: PixelFormat) {
        let frame = make_frame(w, h, pix);
        let interleaved = encode_jpeg(&frame, w, h, pix, 75).expect("interleaved encode");
        let split =
            encode_jpeg_non_interleaved(&frame, w, h, pix, 75).expect("non-interleaved encode");
        let sos_markers = split
            .windows(2)
            .filter(|win| win == &[0xFF, 0xDA])
            .count();
        assert_eq!(sos_markers, 3, "expected 3 SOS segments");

        let mut params = CodecParameters::video(CodecId::new("mjpeg"));
        params.width = Some(w);
        params.height = Some(h);
        let decode = |bytes| {
            let mut dec = make_decoder(&params).unwrap();
            dec.send_packet(&Packet::new(0, TimeBase::new(1, 30), bytes))
                .unwrap();
            let Frame::Video(video) = dec.receive_frame().unwrap() else {
                panic!()
            };
            video
        };
        let a = decode(interleaved);
        let b = decode(split);
        assert_eq!(a.planes.len(), b.planes.len());
        for (pi, (pa, pb)) in a.planes.iter().zip(&b.planes).enumerate() {
            assert_eq!(
                pa.data, pb.data,
                "plane {pi} mismatch between interleaved and non-interleaved decodes"
            );
        }
    }

    #[test]
    fn non_interleaved_yuv420p_matches_interleaved() {
        assert_matches_interleaved(32, 16, PixelFormat::Yuv420P);
    }

    #[test]
    fn non_interleaved_yuv422p_matches_interleaved() {
        assert_matches_interleaved(24, 24, PixelFormat::Yuv422P);
    }

    #[test]
    fn non_interleaved_yuv444p_matches_interleaved() {
        assert_matches_interleaved(16, 16, PixelFormat::Yuv444P);
    }
}
#[cfg(all(test, feature = "registry"))]
mod cmyk_tests {
    use crate::encoder::{encode_jpeg_cmyk_1111, encode_jpeg_progressive_cmyk_1111};
    use crate::registry::make_decoder;
    use oxideav_core::{CodecId, CodecParameters, Frame, Packet, TimeBase};

    /// Peak signal-to-noise ratio in dB between two equally long 8-bit buffers.
    ///
    /// Identical buffers (zero squared error) report a sentinel 99 dB instead of infinity.
    fn psnr(a: &[u8], b: &[u8]) -> f64 {
        assert_eq!(a.len(), b.len());
        let sse: f64 = a
            .iter()
            .zip(b)
            .map(|(&x, &y)| {
                let d = x as f64 - y as f64;
                d * d
            })
            .sum();
        if sse == 0.0 {
            return 99.0;
        }
        20.0 * (255.0 / (sse / a.len() as f64).sqrt()).log10()
    }

    /// Builds four deterministic `w`×`h` test planes: a horizontal ramp, a vertical ramp,
    /// a diagonal ramp, and an XOR texture limited to 0..=127.
    fn make_cmyk_planes(w: usize, h: usize) -> [Vec<u8>; 4] {
        let mut c = vec![0u8; w * h];
        let mut m = vec![0u8; w * h];
        let mut y = vec![0u8; w * h];
        let mut k = vec![0u8; w * h];
        for j in 0..h {
            for i in 0..w {
                c[j * w + i] = ((i * 255 / w.max(1)) as u32).min(255) as u8;
                m[j * w + i] = ((j * 255 / h.max(1)) as u32).min(255) as u8;
                y[j * w + i] = (((i + j) * 255 / (w + h).max(1)) as u32).min(255) as u8;
                k[j * w + i] = ((((i ^ j) * 7) & 0xFF) as u8) / 2;
            }
        }
        [c, m, y, k]
    }

    /// Planes for the YCCK tests: neutral (128) Y, Cb and Cr, plus a K ramp running
    /// 0..=255 across each row.
    fn make_ycck_planes(w: usize, h: usize) -> [Vec<u8>; 4] {
        let neutral = vec![128u8; w * h];
        let k = (0..h)
            .flat_map(move |_| {
                (0..w).map(move |i| ((i * 255 / (w - 1).max(1)) as u32).min(255) as u8)
            })
            .collect();
        [neutral.clone(), neutral.clone(), neutral, k]
    }

    /// Borrows four owned planes as the slice array the CMYK encoders take.
    fn plane_refs(planes: &[Vec<u8>; 4]) -> [&[u8]; 4] {
        [&planes[0], &planes[1], &planes[2], &planes[3]]
    }

    /// Sends one packet to a fresh registry `mjpeg` decoder configured for `w`×`h`.
    /// Returns whatever `receive_frame` yields.
    fn decode_frame(data: Vec<u8>, w: u32, h: u32) -> Result<Frame, oxideav_core::Error> {
        let mut dec_params = CodecParameters::video(CodecId::new("mjpeg"));
        dec_params.width = Some(w);
        dec_params.height = Some(h);
        let mut dec = make_decoder(&dec_params).unwrap();
        dec.send_packet(&Packet::new(0, TimeBase::new(1, 30), data))
            .unwrap();
        dec.receive_frame()
    }

    /// Decodes `data`, which must yield a video frame. Returns
    /// `(plane count, row stride of plane 0, bytes of plane 0)`.
    fn decode_packed(data: Vec<u8>, w: u32, h: u32) -> (usize, usize, Vec<u8>) {
        let Frame::Video(v) = decode_frame(data, w, h).unwrap() else {
            panic!()
        };
        (v.planes.len(), v.planes[0].stride, v.planes[0].data.clone())
    }

    /// Gathers byte `ci` of every pixel from a packed plane holding 4 bytes per pixel.
    fn packed_channel(data: &[u8], stride: usize, w: u32, h: u32, ci: usize) -> Vec<u8> {
        let (w, h) = (w as usize, h as usize);
        (0..h)
            .flat_map(move |j| (0..w).map(move |i| data[j * stride + i * 4 + ci]))
            .collect()
    }

    /// Asserts that each source plane reaches at least 30 dB PSNR against the matching
    /// channel of the decoded packed plane. `prefix` is prepended to the failure message.
    fn assert_channels_close(
        prefix: &str,
        src: &[Vec<u8>; 4],
        data: &[u8],
        stride: usize,
        w: u32,
        h: u32,
    ) {
        for (ci, plane) in src.iter().enumerate() {
            let got = packed_channel(data, stride, w, h, ci);
            let p = psnr(plane, &got);
            assert!(p >= 30.0, "{prefix}component {ci} PSNR too low: {p:.2}");
        }
    }

    #[test]
    fn cmyk_plain_roundtrip() {
        let w = 32u32;
        let h = 16u32;
        let planes = make_cmyk_planes(w as usize, h as usize);
        let refs = plane_refs(&planes);
        let strides = [w as usize; 4];
        let data =
            encode_jpeg_cmyk_1111(w, h, &refs, &strides, 90, None).expect("encode plain CMYK");
        let (plane_count, stride, packed) = decode_packed(data, w, h);
        assert_eq!(plane_count, 1);
        assert_eq!(stride, (w * 4) as usize);
        assert_channels_close("", &planes, &packed, stride, w, h);
    }

    #[test]
    fn cmyk_adobe_inverted_roundtrip() {
        let w = 16u32;
        let h = 16u32;
        let planes = make_cmyk_planes(w as usize, h as usize);
        let refs = plane_refs(&planes);
        let strides = [w as usize; 4];
        let data =
            encode_jpeg_cmyk_1111(w, h, &refs, &strides, 90, Some(0)).expect("encode Adobe CMYK");
        let (_, stride, packed) = decode_packed(data, w, h);
        assert_channels_close("Adobe CMYK ", &planes, &packed, stride, w, h);
    }

    #[test]
    fn ycck_roundtrip_k_plane_matches() {
        let w = 16u32;
        let h = 16u32;
        let planes = make_ycck_planes(w as usize, h as usize);
        let refs = plane_refs(&planes);
        let strides = [w as usize; 4];
        let data = encode_jpeg_cmyk_1111(w, h, &refs, &strides, 90, Some(2)).expect("encode YCCK");
        let (_, stride, packed) = decode_packed(data, w, h);
        let got_k = packed_channel(&packed, stride, w, h, 3);
        let p = psnr(&planes[3], &got_k);
        assert!(p >= 30.0, "YCCK K plane PSNR too low: {p:.2}");
    }

    #[test]
    fn cmyk_progressive_plain_roundtrip() {
        let w = 32u32;
        let h = 16u32;
        let planes = make_cmyk_planes(w as usize, h as usize);
        let refs = plane_refs(&planes);
        let strides = [w as usize; 4];
        let data = encode_jpeg_progressive_cmyk_1111(w, h, &refs, &strides, 90, None)
            .expect("encode progressive plain CMYK");
        let (plane_count, stride, packed) = decode_packed(data, w, h);
        assert_eq!(plane_count, 1, "expected one packed Cmyk plane");
        assert_eq!(stride, (w * 4) as usize, "packed Cmyk row stride = 4 × width");
        assert_channels_close("progressive plain CMYK ", &planes, &packed, stride, w, h);
    }

    #[test]
    fn cmyk_progressive_adobe_inverted_roundtrip() {
        let w = 16u32;
        let h = 16u32;
        let planes = make_cmyk_planes(w as usize, h as usize);
        let refs = plane_refs(&planes);
        let strides = [w as usize; 4];
        let data = encode_jpeg_progressive_cmyk_1111(w, h, &refs, &strides, 90, Some(0))
            .expect("encode progressive Adobe CMYK");
        let (_, stride, packed) = decode_packed(data, w, h);
        assert_channels_close("progressive Adobe CMYK ", &planes, &packed, stride, w, h);
    }

    #[test]
    fn ycck_progressive_k_plane_matches() {
        let w = 16u32;
        let h = 16u32;
        let planes = make_ycck_planes(w as usize, h as usize);
        let refs = plane_refs(&planes);
        let strides = [w as usize; 4];
        let data = encode_jpeg_progressive_cmyk_1111(w, h, &refs, &strides, 90, Some(2))
            .expect("encode progressive YCCK");
        let (_, stride, packed) = decode_packed(data, w, h);
        let got_k = packed_channel(&packed, stride, w, h, 3);
        let p = psnr(&planes[3], &got_k);
        assert!(p >= 30.0, "progressive YCCK K plane PSNR too low: {p:.2}");
    }

    /// A hand-built SOI + SOF2 header with 12-bit precision, 16×16 size, and four 1×1
    /// components on quant table 0, followed directly by EOI. The decoder must report
    /// Unsupported.
    #[test]
    fn cmyk_progressive_p12_rejected() {
        let mut data: Vec<u8> = vec![
            0xFF,
            crate::jpeg::markers::SOI,
            0xFF,
            crate::jpeg::markers::SOF2,
        ];
        let length: u16 = 2 + 1 + 2 + 2 + 1 + 4 * 3;
        data.extend_from_slice(&length.to_be_bytes());
        data.push(12);
        data.extend_from_slice(&16u16.to_be_bytes());
        data.extend_from_slice(&16u16.to_be_bytes());
        data.push(4);
        for id in 1u8..=4 {
            data.extend_from_slice(&[id, 0x11, 0]);
        }
        data.extend_from_slice(&[0xFF, crate::jpeg::markers::EOI]);
        let err = decode_frame(data, 16, 16).expect_err("expected Unsupported");
        assert!(
            matches!(err, oxideav_core::Error::Unsupported(_)),
            "expected Unsupported, got {err:?}"
        );
    }
}
#[cfg(all(test, feature = "registry"))]
mod precision_12_tests {
    use crate::encoder::{
        encode_grayscale_jpeg_12bit, encode_yuv_jpeg_12bit, encode_yuv_jpeg_progressive_12bit,
    };
    use crate::registry::make_decoder;
    use oxideav_core::{CodecId, CodecParameters, Frame, Packet, PixelFormat, TimeBase};

    /// Sends one packet to a fresh registry `mjpeg` decoder configured for `w`×`h`.
    /// Returns whatever `receive_frame` yields.
    fn decode_packet(data: Vec<u8>, w: u32, h: u32) -> Result<Frame, oxideav_core::Error> {
        let mut dec_params = CodecParameters::video(CodecId::new("mjpeg"));
        dec_params.width = Some(w);
        dec_params.height = Some(h);
        let mut dec = make_decoder(&dec_params).unwrap();
        dec.send_packet(&Packet::new(0, TimeBase::new(1, 30), data))
            .unwrap();
        dec.receive_frame()
    }

    /// Decodes `data`, which must produce a video frame. Returns every plane as
    /// `(row stride in bytes, plane bytes)`.
    fn decode_planes(data: Vec<u8>, w: u32, h: u32) -> Vec<(usize, Vec<u8>)> {
        let Frame::Video(v) = decode_packet(data, w, h).unwrap() else {
            panic!("decoder did not emit a video frame")
        };
        v.planes.iter().map(|p| (p.stride, p.data.clone())).collect()
    }

    #[test]
    fn gray_12bit_roundtrip() {
        let w = 16u32;
        let h = 16u32;
        let samples: Vec<u16> = (0..h as usize)
            .flat_map(move |j| (0..w as usize).map(move |i| 2000 + ((i + j) as u16)))
            .collect();
        let stride = w as usize;
        let data =
            encode_grayscale_jpeg_12bit(w, h, &samples, stride, 90).expect("encode 12-bit gray");
        let planes = decode_planes(data, w, h);
        assert_eq!(planes.len(), 1);
        let (y_stride, y_bytes) = &planes[0];
        assert_eq!(*y_stride, (w * 2) as usize);
        let got = unpack_le_u16_plane(y_bytes, *y_stride, w as usize, h as usize);
        for (orig, dec) in samples.iter().zip(got.iter()) {
            let diff = (*orig as i32 - *dec as i32).abs();
            assert!(diff < 16, "12-bit roundtrip diff too large: {diff}");
        }
    }

    /// Builds deterministic 12-bit Y/Cb/Cr planes for a `w`×`h` image.
    ///
    /// Chroma is subsampled by `h_factor`×`v_factor`, with dimensions rounded up.
    /// Returns `(y, cb, cr, chroma_width, chroma_height)`.
    fn build_yuv_12bit(
        w: usize,
        h: usize,
        h_factor: u8,
        v_factor: u8,
    ) -> (Vec<u16>, Vec<u16>, Vec<u16>, usize, usize) {
        let c_w = w.div_ceil(h_factor as usize);
        let c_h = h.div_ceil(v_factor as usize);
        let mut y = vec![0u16; w * h];
        let mut cb = vec![0u16; c_w * c_h];
        let mut cr = vec![0u16; c_w * c_h];
        for j in 0..h {
            for i in 0..w {
                y[j * w + i] = 2000 + ((i + j) as u16);
            }
        }
        for j in 0..c_h {
            for i in 0..c_w {
                cb[j * c_w + i] = 2040 + ((i ^ j) as u16 & 0x07);
                cr[j * c_w + i] = 2056 + (((i + 2 * j) as u16) & 0x07);
            }
        }
        (y, cb, cr, c_w, c_h)
    }

    /// Reads a `w`×`h` plane of little-endian u16 samples from `data`, whose rows are
    /// `stride` bytes apart.
    fn unpack_le_u16_plane(data: &[u8], stride: usize, w: usize, h: usize) -> Vec<u16> {
        let mut out = Vec::with_capacity(w * h);
        for j in 0..h {
            for i in 0..w {
                let o = j * stride + i * 2;
                out.push(data[o] as u16 | ((data[o + 1] as u16) << 8));
            }
        }
        out
    }

    /// Asserts the two sample planes have equal length and differ by less than 24 at
    /// every index.
    fn assert_plane_close(label: &str, orig: &[u16], dec: &[u16]) {
        assert_eq!(orig.len(), dec.len(), "{label}: length mismatch");
        for (k, (o, d)) in orig.iter().zip(dec.iter()).enumerate() {
            let diff = (*o as i32 - *d as i32).abs();
            assert!(
                diff < 24,
                "{label}: idx {k} diff too large (orig={o} dec={d})"
            );
        }
    }

    /// Encodes a synthetic 12-bit YCbCr image with `encode_yuv_jpeg_12bit`, or with
    /// `encode_yuv_jpeg_progressive_12bit` when `progressive` is set. Decodes it and checks
    /// plane count, strides, and per-sample closeness.
    fn run_yuv_12bit_roundtrip(
        w: u32,
        h: u32,
        h_factor: u8,
        v_factor: u8,
        progressive: bool,
        expect_pix: PixelFormat,
    ) {
        // NOTE(review): `expect_pix` is not asserted; the decoded frame is only inspected
        // through its planes here. It is kept so each test names the format it targets.
        let _ = expect_pix;
        let (y, cb, cr, c_w, c_h) = build_yuv_12bit(w as usize, h as usize, h_factor, v_factor);
        let data = if progressive {
            encode_yuv_jpeg_progressive_12bit(w, h, &y, &cb, &cr, h_factor, v_factor, 90)
                .expect("encode 12-bit progressive yuv")
        } else {
            encode_yuv_jpeg_12bit(w, h, &y, &cb, &cr, h_factor, v_factor, 90)
                .expect("encode 12-bit yuv")
        };
        let planes = decode_planes(data, w, h);
        assert_eq!(planes.len(), 3, "expected three planes");
        assert_eq!(planes[0].0, (w * 2) as usize, "Y stride");
        assert_eq!(planes[1].0, c_w * 2, "Cb stride");
        assert_eq!(planes[2].0, c_w * 2, "Cr stride");
        let got_y = unpack_le_u16_plane(&planes[0].1, planes[0].0, w as usize, h as usize);
        let got_cb = unpack_le_u16_plane(&planes[1].1, planes[1].0, c_w, c_h);
        let got_cr = unpack_le_u16_plane(&planes[2].1, planes[2].0, c_w, c_h);
        assert_plane_close("Y", &y, &got_y);
        assert_plane_close("Cb", &cb, &got_cb);
        assert_plane_close("Cr", &cr, &got_cr);
    }

    #[test]
    fn yuv444_12bit_roundtrip() {
        run_yuv_12bit_roundtrip(16, 16, 1, 1, false, PixelFormat::Yuv444P12Le);
    }

    #[test]
    fn yuv422_12bit_roundtrip() {
        run_yuv_12bit_roundtrip(16, 16, 2, 1, false, PixelFormat::Yuv422P12Le);
    }

    #[test]
    fn yuv420_12bit_roundtrip_via_yuv_helper() {
        run_yuv_12bit_roundtrip(16, 16, 2, 2, false, PixelFormat::Yuv420P12Le);
    }

    #[test]
    fn yuv_12bit_4x1_luma_rejected() {
        let w = 16u32;
        let h = 16u32;
        let (y, cb, cr, _c_w, _c_h) = build_yuv_12bit(w as usize, h as usize, 4, 1);
        let data =
            encode_yuv_jpeg_12bit(w, h, &y, &cb, &cr, 4, 1, 90).expect("encode 12-bit yuv 4:1:1");
        let err = decode_packet(data, w, h).expect_err("expected Unsupported");
        assert!(
            matches!(err, oxideav_core::Error::Unsupported(_)),
            "expected Unsupported, got {err:?}"
        );
    }

    #[test]
    fn yuv444_12bit_progressive_roundtrip() {
        run_yuv_12bit_roundtrip(16, 16, 1, 1, true, PixelFormat::Yuv444P12Le);
    }

    #[test]
    fn yuv422_12bit_progressive_roundtrip() {
        run_yuv_12bit_roundtrip(16, 16, 2, 1, true, PixelFormat::Yuv422P12Le);
    }

    #[test]
    fn yuv420_12bit_progressive_roundtrip() {
        run_yuv_12bit_roundtrip(16, 16, 2, 2, true, PixelFormat::Yuv420P12Le);
    }
}
// Lossless-mode tests. Covers SOF3 (Huffman) through the registry decoder and SOF11
// (arithmetic) through a local reference encoder. Also checks that differential
// (hierarchical) frame headers are rejected as Unsupported.
#[cfg(all(test, feature = "registry"))]
mod lossless_tests {
use super::{decode_jpeg, Error};
use crate::encoder::encode_lossless_grayscale_jpeg_8bit;
use crate::registry::make_decoder;
use oxideav_core::{CodecId, CodecParameters, Frame, Packet, TimeBase};
// Encodes an 8-bit grayscale texture with the crate's lossless encoder, checks that an
// SOF3 marker is present, and requires a bit-exact decode.
#[test]
fn lossless_8bit_gray_exact_roundtrip() {
let w = 24u32;
let h = 16u32;
let mut samples = vec![0u8; (w * h) as usize];
for j in 0..h as usize {
for i in 0..w as usize {
samples[j * w as usize + i] =
((i as i32 * 3 + j as i32 * 5 + ((i ^ j) as i32 & 7)) & 0xFF) as u8;
}
}
let data = encode_lossless_grayscale_jpeg_8bit(w, h, &samples, w as usize)
.expect("encode lossless");
assert!(
data.windows(2).any(|x| x == [0xFF, 0xC3]),
"SOF3 marker missing from lossless output"
);
let mut dec_params = CodecParameters::video(CodecId::new("mjpeg"));
dec_params.width = Some(w);
dec_params.height = Some(h);
let mut dec = make_decoder(&dec_params).unwrap();
dec.send_packet(&Packet::new(0, TimeBase::new(1, 30), data))
.unwrap();
let Frame::Video(v) = dec.receive_frame().unwrap() else {
panic!()
};
assert_eq!(v.planes.len(), 1);
assert_eq!(v.planes[0].stride, w as usize);
for j in 0..h as usize {
for i in 0..w as usize {
let got = v.planes[0].data[j * v.planes[0].stride + i];
let want = samples[j * w as usize + i];
assert_eq!(got, want, "mismatch at ({i},{j})");
}
}
}
// Appends one marker segment: FF <marker>, a big-endian length counting itself plus
// the payload, then the payload. The `as u16` cast truncates for payloads over
// 65533 bytes; the callers here only emit short headers.
fn put_seg(out: &mut Vec<u8>, marker: u8, payload: &[u8]) {
out.extend_from_slice(&[0xFF, marker]);
out.extend_from_slice(&((payload.len() + 2) as u16).to_be_bytes());
out.extend_from_slice(payload);
}
// Reference encoder for an SOF11 (lossless, arithmetic-coded) JPEG with a single
// interleaved scan.
//
// - `planes`: one sample plane per component, row-major `width`×`height`, all with 1×1
//   sampling. With `pt > 0` the samples are the already point-transformed values; the
//   decoder is expected to return `sample << pt`.
// - `predictor`: the lossless predictor selector, 1..=7 (written as Ss).
// - `restart_interval`: samples between RSTn markers (0 = none). Intervals are restarted
//   as if at the top of the image, so callers should use whole rows.
// - `dac_lu`: optional (L, U) conditioning bounds. When absent, (0, 1) is used for coding
//   and no DAC segment is written.
#[allow(clippy::too_many_arguments)]
fn encode_sof11_jpeg(
width: usize,
height: usize,
planes: &[Vec<u32>],
precision: u8,
predictor: u8,
pt: u8,
restart_interval: u16,
dac_lu: Option<(u8, u8)>,
) -> Vec<u8> {
use crate::jpeg::arith::{encode_lossless_diff, ArithEncoder, LosslessStats};
let nc = planes.len();
// SOI. The optional DAC payload is Tc/Tb = 0x00 (table 0), then U in the high nibble
// and L in the low nibble.
let mut out = vec![0xFF, 0xD8]; if let Some((l, u)) = dac_lu {
put_seg(&mut out, 0xCC, &[0x00, (u << 4) | l]);
}
if restart_interval != 0 {
put_seg(&mut out, 0xDD, &restart_interval.to_be_bytes());
}
let mut sof = vec![precision];
sof.extend_from_slice(&(height as u16).to_be_bytes());
sof.extend_from_slice(&(width as u16).to_be_bytes());
sof.push(nc as u8);
for ci in 0..nc {
sof.extend_from_slice(&[ci as u8 + 1, 0x11, 0]);
}
// Emit SOF11, then start the SOS payload: Ns, (Cs, table selector 0x00) per component,
// then Ss = predictor, Se = 0, and the Ah/Al byte = pt.
put_seg(&mut out, 0xCB, &sof); let mut sos = vec![nc as u8];
for ci in 0..nc {
sos.extend_from_slice(&[ci as u8 + 1, 0x00]);
}
sos.extend_from_slice(&[predictor, 0, pt]);
put_seg(&mut out, 0xDA, &sos);
// Prediction used for the first sample of the image and of every restart interval.
let sample_bits = (precision - pt) as u32;
let origin: u32 = 1 << (sample_bits - 1);
let (l, u) = dac_lu.unwrap_or((0, 1));
let mut stats: Vec<LosslessStats> = (0..nc)
.map(|_| {
let mut s = LosslessStats::new();
s.l = l;
s.u = u;
s
})
.collect();
// Per-component difference history for the conditioning inputs: `prev_diff` is the
// row above, `cur_diff` the row being coded. Both are swapped at the end of each row.
let mut prev_diff = vec![vec![0i32; width]; nc];
let mut cur_diff = vec![vec![0i32; width]; nc];
let mut enc = ArithEncoder::new();
let mut since_restart = 0u32;
let mut rst = 0u8;
let mut reset_pred = true;
let mut first_line_y = 0usize;
for y in 0..height {
for x in 0..width {
// Every `restart_interval` samples (one MCU = one sample of each component):
// flush the coder, emit RSTn, and reset statistics, difference history and
// prediction. The row where the interval starts is then treated as a first line.
if restart_interval != 0
&& since_restart != 0
&& since_restart % restart_interval as u32 == 0
{
out.extend_from_slice(&std::mem::take(&mut enc).finish());
out.extend_from_slice(&[0xFF, 0xD0 + rst]);
rst = (rst + 1) % 8;
for s in stats.iter_mut() {
s.reset();
}
for r in prev_diff.iter_mut() {
r.fill(0);
}
for r in cur_diff.iter_mut() {
r.fill(0);
}
reset_pred = true;
first_line_y = y;
}
for ci in 0..nc {
let plane = &planes[ci];
let pred: u32 = if reset_pred {
origin
} else if y == first_line_y {
plane[y * width + x - 1]
} else if x == 0 {
plane[(y - 1) * width + x]
} else {
let ra = plane[y * width + x - 1];
let rb = plane[(y - 1) * width + x];
let rc = plane[(y - 1) * width + x - 1];
// Predictors 1..=7 (T.81 Table H.1). The u32 wrapping ops and logical `>>`
// differ from the signed formulas only by multiples of 2^16, which the
// `& 0xFFFF` below discards.
match predictor {
1 => ra,
2 => rb,
3 => rc,
4 => ra.wrapping_add(rb).wrapping_sub(rc),
5 => ra.wrapping_add(rb.wrapping_sub(rc) >> 1),
6 => rb.wrapping_add(ra.wrapping_sub(rc) >> 1),
7 => (ra.wrapping_add(rb)) >> 1,
_ => unreachable!(),
}
};
let px = plane[y * width + x];
// Difference modulo 2^16, mapped into -32768..=32767.
let dm = (px.wrapping_sub(pred) & 0xFFFF) as i32;
let dm = if dm >= 0x8000 { dm - 0x10000 } else { dm };
// Conditioning inputs: Da = difference to the left (0 in column 0), Db =
// difference above (0 on the first line / right after a restart).
let da = if x == 0 { 0 } else { cur_diff[ci][x - 1] };
let db = prev_diff[ci][x];
encode_lossless_diff(&mut enc, &mut stats[ci], da, db, dm).unwrap();
cur_diff[ci][x] = dm;
}
reset_pred = false;
since_restart += 1;
}
std::mem::swap(&mut prev_diff, &mut cur_diff);
}
// Flush the final interval, then EOI.
out.extend_from_slice(&enc.finish());
out.extend_from_slice(&[0xFF, 0xD9]); out
}
// Every predictor 1..=7 must round-trip an 8-bit grayscale texture exactly.
#[test]
fn sof11_gray8_exact_roundtrip_all_predictors() {
let w = 24usize;
let h = 16usize;
let mut plane = vec![0u32; w * h];
for j in 0..h {
for i in 0..w {
plane[j * w + i] =
((i as i32 * 3 + j as i32 * 5 + ((i ^ j) as i32 & 7)) & 0xFF) as u32;
}
}
for predictor in 1..=7u8 {
let data = encode_sof11_jpeg(w, h, &[plane.clone()], 8, predictor, 0, 0, None);
assert!(
data.windows(2).any(|x| x == [0xFF, 0xCB]),
"SOF11 marker missing"
);
let v = decode_jpeg(&data, None).unwrap();
assert_eq!(v.planes.len(), 1);
assert_eq!(v.planes[0].stride, w);
for j in 0..h {
for i in 0..w {
assert_eq!(
v.planes[0].data[j * w + i] as u32,
plane[j * w + i],
"pred {predictor} mismatch at ({i},{j})"
);
}
}
}
}
// 16-bit samples from an xorshift32 stream. The output is expected as little-endian
// u16 with a stride of 2 × width.
#[test]
fn sof11_gray16_exact_roundtrip() {
let w = 16usize;
let h = 12usize;
let mut s = 0x9E37_79B9u32;
let mut plane = vec![0u32; w * h];
for v in plane.iter_mut() {
s ^= s << 13;
s ^= s >> 17;
s ^= s << 5;
*v = s & 0xFFFF;
}
let data = encode_sof11_jpeg(w, h, &[plane.clone()], 16, 4, 0, 0, None);
let v = decode_jpeg(&data, None).unwrap();
assert_eq!(v.planes.len(), 1);
assert_eq!(v.planes[0].stride, w * 2);
for i in 0..w * h {
let got =
u16::from_le_bytes([v.planes[0].data[i * 2], v.planes[0].data[i * 2 + 1]]) as u32;
assert_eq!(got, plane[i], "mismatch at sample {i}");
}
}
// Three components are expected back as one packed plane with 3 bytes per pixel.
#[test]
fn sof11_rgb8_exact_roundtrip() {
let w = 20usize;
let h = 14usize;
let mut planes = vec![vec![0u32; w * h]; 3];
for j in 0..h {
for i in 0..w {
planes[0][j * w + i] = ((i * 11 + j * 3) & 0xFF) as u32;
planes[1][j * w + i] = ((i * 2 + j * 17) & 0xFF) as u32;
planes[2][j * w + i] = (((i ^ j) * 29) & 0xFF) as u32;
}
}
let data = encode_sof11_jpeg(w, h, &planes, 8, 1, 0, 0, None);
let v = decode_jpeg(&data, None).unwrap();
assert_eq!(v.planes.len(), 1);
assert_eq!(v.planes[0].stride, w * 3);
for i in 0..w * h {
for (ci, plane) in planes.iter().enumerate() {
assert_eq!(
v.planes[0].data[i * 3 + ci] as u32,
plane[i],
"component {ci} mismatch at sample {i}"
);
}
}
}
// A restart interval of two full rows must produce RSTn markers and still decode
// exactly.
#[test]
fn sof11_restart_interval_roundtrip() {
let w = 16usize;
let h = 24usize;
let mut plane = vec![0u32; w * h];
for j in 0..h {
for i in 0..w {
plane[j * w + i] = ((i * 7 + j * 13 + (i & j)) & 0xFF) as u32;
}
}
let data = encode_sof11_jpeg(w, h, &[plane.clone()], 8, 5, 0, (w * 2) as u16, None);
assert!(
data.windows(2)
.any(|x| x[0] == 0xFF && (0xD0..=0xD7).contains(&x[1])),
"no RSTn marker found in the scan"
);
let v = decode_jpeg(&data, None).unwrap();
for i in 0..w * h {
assert_eq!(v.planes[0].data[i] as u32, plane[i], "mismatch at {i}");
}
}
// Non-default DAC conditioning bounds (L = 2, U = 5) must be honoured by the decoder.
#[test]
fn sof11_dac_conditioning_roundtrip() {
let w = 24usize;
let h = 16usize;
let mut plane = vec![0u32; w * h];
for j in 0..h {
for i in 0..w {
plane[j * w + i] = ((i * 19 + j * 31 + ((i * j) & 15)) & 0xFF) as u32;
}
}
let data = encode_sof11_jpeg(w, h, &[plane.clone()], 8, 2, 0, 0, Some((2, 5)));
let v = decode_jpeg(&data, None).unwrap();
for i in 0..w * h {
assert_eq!(v.planes[0].data[i] as u32, plane[i], "mismatch at {i}");
}
}
// With Pt = 2, 6-bit input samples are expected back scaled by 1 << pt.
#[test]
fn sof11_point_transform_roundtrip() {
let w = 12usize;
let h = 10usize;
let pt = 2u8;
let mut plane = vec![0u32; w * h];
for j in 0..h {
for i in 0..w {
plane[j * w + i] = ((i * 5 + j * 9) & 0x3F) as u32;
}
}
let data = encode_sof11_jpeg(w, h, &[plane.clone()], 8, 1, pt, 0, None);
let v = decode_jpeg(&data, None).unwrap();
for i in 0..w * h {
assert_eq!(
v.planes[0].data[i] as u32,
plane[i] << pt,
"mismatch at {i}"
);
}
}
// SOF13 (0xCD) must be reported as Unsupported.
// NOTE(review): the header declares one component (Nf = 1) but omits its 3-byte spec,
// so the test presumably relies on the decoder rejecting the frame type before it
// validates component data. Confirm if the parser order changes.
#[test]
fn sof13_differential_arithmetic_still_rejected() {
let bytes = vec![
0xFF, 0xD8, 0xFF, 0xCD, 0x00, 0x08, 0x08, 0x00, 0x08, 0x00, 0x08, 0x01, 0xFF, 0xD9, ];
let err = decode_jpeg(&bytes, None).expect_err("expected decode error");
assert!(
matches!(err, Error::Unsupported(_)),
"expected Unsupported, got {err:?}"
);
}
// Expects a differential frame header to be rejected through the registry decoder.
// NOTE(review): 0xC5 is SOF5, the *Huffman*-coded differential sequential type. The
// arithmetic differential types are SOF13..=SOF15, so despite its name this test covers
// hierarchical rejection in general. As above, the header omits the component spec.
#[test]
fn hierarchical_arithmetic_still_rejected() {
let bytes = vec![
0xFF, 0xD8, 0xFF, 0xC5, 0x00, 0x08, 0x08, 0x00, 0x08, 0x00, 0x08, 0x01, 0xFF, 0xD9, ];
let mut dec_params = CodecParameters::video(CodecId::new("mjpeg"));
dec_params.width = Some(8);
dec_params.height = Some(8);
let mut dec = make_decoder(&dec_params).unwrap();
dec.send_packet(&Packet::new(0, TimeBase::new(1, 30), bytes))
.unwrap();
let err = dec.receive_frame().expect_err("expected decode error");
assert!(
matches!(err, oxideav_core::Error::Unsupported(_)),
"expected Unsupported, got {err:?}"
);
}
}
#[cfg(test)]
mod sof10_tests {
use super::decode_jpeg;
use crate::jpeg::arith::{encode_magnitude, AcStats, ArithEncoder, Context, DcStats};
use crate::jpeg::dct::idct8x8;
use crate::jpeg::zigzag::ZIGZAG;
/// Appends one JPEG marker segment to `out`: `FF <marker>`, then a big-endian 16-bit
/// length covering the length field plus the payload, then the payload bytes.
///
/// # Panics
/// Panics, before writing anything, if `payload` is longer than 65533 bytes. A longer
/// payload cannot be described by the 16-bit length field, and a plain `as u16` cast
/// would silently write a corrupt length.
fn put_seg(out: &mut Vec<u8>, marker: u8, payload: &[u8]) {
    let seg_len = u16::try_from(payload.len() + 2)
        .expect("JPEG marker segment payload exceeds 65533 bytes");
    out.extend_from_slice(&[0xFF, marker]);
    out.extend_from_slice(&seg_len.to_be_bytes());
    out.extend_from_slice(payload);
}
/// Codes `bit` through a throwaway context that starts at state index 0 with MPS 0.
///
/// Each call gets a fresh context, so the probability estimate never adapts between
/// calls. Used for the sign and correction bits.
fn code_fixed_bit(e: &mut ArithEncoder, bit: u8) {
    e.code_bit(&mut Context { idx: 0, mps: 0 }, bit);
}
/// Arithmetic-codes one DC difference using the adaptive DC statistics in `dc`.
///
/// A zero/nonzero decision is coded at the context selected by `dc.dc_context()`. For a
/// nonzero `diff`, the sign follows at the next bin, then `|diff| - 1` is coded through
/// `encode_magnitude`, starting two bins further on (one more for negative values).
/// `dc.prev_diff` is updated to `diff` in every case.
fn encode_dc_diff(e: &mut ArithEncoder, dc: &mut DcStats, diff: i32) {
    let ctx = dc.dc_context();
    let nonzero = diff != 0;
    e.code_bit(&mut dc.bins[ctx], u8::from(nonzero));
    if nonzero {
        let negative = u8::from(diff < 0);
        e.code_bit(&mut dc.bins[ctx + 1], negative);
        let magnitude_ctx = ctx + 2 + usize::from(negative);
        encode_magnitude(e, &mut dc.bins, magnitude_ctx, 20, diff.unsigned_abs() - 1);
    }
    dc.prev_diff = diff;
}
/// Codes the DC coefficient `full` of one block for a progressive DC scan.
///
/// On a first pass (`ah == 0`), the point-transformed value `full >> al` is coded as a
/// difference from `dc.pred`, which is then advanced. On a refinement pass
/// (`ah != 0`), only bit `al` of `full` is emitted, through the fixed-estimate bin.
fn encode_dc_unit(e: &mut ArithEncoder, dc: &mut DcStats, full: i32, ah: u8, al: u8) {
    let scaled = full >> al;
    if ah != 0 {
        code_fixed_bit(e, (scaled & 1) as u8);
        return;
    }
    encode_dc_diff(e, dc, scaled - dc.pred);
    dc.pred = scaled;
}
// Codes `sz` (|coefficient| - 1, as passed by `encode_ac_band`) for the AC coefficient at
// zigzag index `k`. This is the magnitude-category / magnitude-bits scheme of T.81
// Figures F.8 and F.9.
//
// - The first two decisions use bin 3*(k-1)+2: sz == 0 stops after one `0`, and
//   sz == 1 stops after `1, 0`.
// - Larger values continue with a unary exponent in the bins from 189 (when k <= kx) or
//   217 (otherwise), one `1` per power of two reached.
// - The remaining low bits of `sz` are then sent MSB-first in the single bin 14 past the
//   terminating exponent bin.
fn encode_ac_mag(e: &mut ArithEncoder, bins: &mut [Context], k: usize, kx: u8, sz: u32) {
let s_first = 3 * (k - 1) + 2;
if sz == 0 {
e.code_bit(&mut bins[s_first], 0);
return;
}
e.code_bit(&mut bins[s_first], 1);
if sz < 2 {
e.code_bit(&mut bins[s_first], 0);
return;
}
e.code_bit(&mut bins[s_first], 1);
// Exponent: after the loop, `m` is the smallest power of two above `sz` (at least 4).
let mut m = 4u32;
let mut s = if (k as u8) <= kx { 189 } else { 217 };
while sz >= m {
e.code_bit(&mut bins[s], 1);
m <<= 1;
s += 1;
}
e.code_bit(&mut bins[s], 0);
// Magnitude bits below the leading one (bit m/2 is implied), MSB first.
let m_bin = s + 14;
let mut bit = m >> 2;
while bit != 0 {
e.code_bit(&mut bins[m_bin], u8::from(sz & bit != 0));
bit >>= 1;
}
}
// Codes one first-pass AC band `ss..=se` (with ss >= 1) of a block whose coefficients in
// `vals` are already point-transformed, using the adaptive AC statistics in `ac`.
//
// - At each position k: an end-of-band decision is coded at bin 3*(k-1), 1 = nothing
//   nonzero remains.
// - Each zero is marked by a 0 at bin 3*(k-1)+1; the next nonzero coefficient gets a 1
//   there, then its sign via the fixed-estimate bin, then |v| - 1 via `encode_ac_mag`.
// - No end-of-band decision follows a coefficient coded at `se`.
fn encode_ac_band(
e: &mut ArithEncoder,
ac: &mut AcStats,
vals: &[i32; 64],
ss: usize,
se: usize,
) {
// `eob` is one past the last nonzero coefficient in the band (== ss if none).
let mut eob = ss;
for k in ss..=se {
if vals[k] != 0 {
eob = k + 1;
}
}
let mut k = ss;
loop {
let se_bin = 3 * (k - 1);
if k >= eob {
e.code_bit(&mut ac.bins[se_bin], 1);
return;
}
e.code_bit(&mut ac.bins[se_bin], 0);
// Terminates because a nonzero value exists at eob - 1 >= k.
while vals[k] == 0 {
e.code_bit(&mut ac.bins[3 * (k - 1) + 1], 0);
k += 1;
}
e.code_bit(&mut ac.bins[3 * (k - 1) + 1], 1);
let v = vals[k];
code_fixed_bit(e, u8::from(v < 0));
encode_ac_mag(e, &mut ac.bins, k, ac.kx, v.unsigned_abs() - 1);
if k == se {
return;
}
k += 1;
}
}
// Codes one successive-approximation refinement pass for AC band `ss..=se`, sending bit
// `al` of each coefficient magnitude in `full` (final, unshifted values).
//
// - `hist(k) != 0` marks coefficients already nonzero from earlier passes (this assumes
//   the previous pass used Ah = al + 1). They get their refinement bit at 3*(k-1)+2.
// - Coefficients becoming nonzero in this pass get a 1 at 3*(k-1)+1 plus a
//   fixed-estimate sign bit; coefficients still zero get a 0 there.
// - End-of-band decisions at 3*(k-1) are only coded once `k` is past the last previously
//   nonzero coefficient (`eobx`).
// - `bins` is a separate context array owned by the caller.
fn encode_ac_refine_band(
e: &mut ArithEncoder,
bins: &mut [Context; 189],
full: &[i32; 64],
ss: usize,
se: usize,
al: u8,
) {
let shifted = |k: usize| (full[k].unsigned_abs() >> al) as i32;
let hist = |k: usize| (full[k].unsigned_abs() >> (al + 1)) as i32;
// `eob`: one past the last coefficient nonzero after this pass.
// `eobx`: one past the last one nonzero before it.
let mut eob = ss;
let mut eobx = ss;
for k in ss..=se {
if shifted(k) != 0 {
eob = k + 1;
}
if hist(k) != 0 {
eobx = k + 1;
}
}
let mut k = ss;
loop {
if k >= eobx {
let se_bin = 3 * (k - 1);
if k >= eob {
e.code_bit(&mut bins[se_bin], 1);
return;
}
e.code_bit(&mut bins[se_bin], 0);
}
// Walk forward to the next coefficient that is nonzero after this pass; it exists
// because k < eob here.
loop {
if hist(k) != 0 {
let t = ((full[k].unsigned_abs() >> al) & 1) as u8;
e.code_bit(&mut bins[3 * (k - 1) + 2], t);
break;
}
if shifted(k) != 0 {
e.code_bit(&mut bins[3 * (k - 1) + 1], 1);
code_fixed_bit(e, u8::from(full[k] < 0));
break;
}
e.code_bit(&mut bins[3 * (k - 1) + 1], 0);
k += 1;
}
if k == se {
return;
}
k += 1;
}
}
// One scan of a progressive image: (component indices, Ss, Se, Ah, Al).
type ScanDesc = (Vec<usize>, u8, u8, u8, u8);
// Reference encoder for an SOF10 (progressive, arithmetic-coded) JPEG.
//
// - `comps`: (H, V) sampling factors per component.
// - `blocks[ci]`: that component's coefficient blocks in zigzag order, row-major on the
//   MCU-padded grid of `mcus_x * H` blocks per row.
// - The single DQT table is all ones, so the coefficients are emitted unscaled.
// - `scans`: the progressive scan script. A DC scan (Ss == 0) with several components is
//   interleaved; every other scan codes only `scomps[0]`.
// - `restart_interval`: counted in MCUs, or in blocks for non-interleaved scans.
// - `kx_dac`: written as a DAC segment for AC table 0 when given; otherwise Kx = 5 is
//   used for coding.
#[allow(clippy::too_many_arguments)]
fn encode_sof10_jpeg(
width: usize,
height: usize,
precision: u8,
comps: &[(u8, u8)],
blocks: &[Vec<[i32; 64]>],
scans: &[ScanDesc],
restart_interval: u16,
kx_dac: Option<u8>,
) -> Vec<u8> {
let nc = comps.len();
// SOI, then a DQT payload: Pq/Tq = 0x00 (8-bit table 0) followed by 64 ones.
let mut out = vec![0xFF, 0xD8]; let mut dqt = vec![0u8];
dqt.extend(std::iter::repeat(1u8).take(64));
put_seg(&mut out, 0xDB, &dqt);
if let Some(kx) = kx_dac {
// DAC: Tc/Tb = 0x10 (AC table 0), conditioning value Kx.
put_seg(&mut out, 0xCC, &[0x10, kx]);
}
if restart_interval != 0 {
put_seg(&mut out, 0xDD, &restart_interval.to_be_bytes());
}
let mut sof = vec![precision];
sof.extend_from_slice(&(height as u16).to_be_bytes());
sof.extend_from_slice(&(width as u16).to_be_bytes());
sof.push(nc as u8);
for (ci, (h, v)) in comps.iter().enumerate() {
sof.extend_from_slice(&[ci as u8 + 1, (h << 4) | v, 0]);
}
put_seg(&mut out, 0xCA, &sof);
let h_max = comps.iter().map(|c| c.0).max().unwrap() as usize;
let v_max = comps.iter().map(|c| c.1).max().unwrap() as usize;
let mcus_x = width.div_ceil(8 * h_max);
let mcus_y = height.div_ceil(8 * v_max);
let kx = kx_dac.unwrap_or(5);
for (scomps, ss, se, ah, al) in scans {
let mut sos = vec![scomps.len() as u8];
for &ci in scomps {
sos.extend_from_slice(&[ci as u8 + 1, 0x00]);
}
sos.extend_from_slice(&[*ss, *se, (ah << 4) | al]);
put_seg(&mut out, 0xDA, &sos);
let is_dc = *ss == 0;
let interleaved = is_dc && scomps.len() > 1;
// NOTE(review): non-interleaved scans walk the MCU-padded block grid
// (mcus_x * H by mcus_y * V). T.81 A.2.2 sizes a non-interleaved scan from the
// component's own dimensions (ceil(component_width / 8) blocks per row). That can be
// smaller for subsampled images whose size is not MCU-aligned. Callers presumably use
// MCU-aligned sizes; confirm before testing odd dimensions.
let (sm_x, sm_y) = if interleaved {
(mcus_x, mcus_y)
} else {
let (h, v) = comps[scomps[0]];
(mcus_x * h as usize, mcus_y * v as usize)
};
// Fresh statistics for every scan.
let mut dc_stats: Vec<DcStats> = (0..scomps.len()).map(|_| DcStats::new()).collect();
let mut ac = AcStats::new();
ac.kx = kx;
let mut refine_bins = [Context::default(); 189];
let mut enc = ArithEncoder::new();
let mut since_restart = 0u32;
let mut rst = 0u8;
for my in 0..sm_y {
for mx in 0..sm_x {
// Restart: flush the coder, emit RSTn, and reset all statistics
// (and with them the DC predictors held in `DcStats`).
if restart_interval != 0
&& since_restart != 0
&& since_restart % restart_interval as u32 == 0
{
out.extend_from_slice(&std::mem::take(&mut enc).finish());
out.extend_from_slice(&[0xFF, 0xD0 + rst]);
rst = (rst + 1) % 8;
for s in dc_stats.iter_mut() {
*s = DcStats::new();
}
ac.reset();
refine_bins = [Context::default(); 189];
}
if is_dc {
for (sidx, &ci) in scomps.iter().enumerate() {
let (h, v) = comps[ci];
let blocks_x = mcus_x * h as usize;
if interleaved {
// All H×V blocks of this component inside MCU (mx, my).
for by in 0..v as usize {
for bx in 0..h as usize {
let bi = (my * v as usize + by) * blocks_x
+ mx * h as usize
+ bx;
encode_dc_unit(
&mut enc,
&mut dc_stats[sidx],
blocks[ci][bi][0],
*ah,
*al,
);
}
}
} else {
let bi = my * blocks_x + mx;
encode_dc_unit(
&mut enc,
&mut dc_stats[sidx],
blocks[ci][bi][0],
*ah,
*al,
);
}
}
} else {
let ci = scomps[0];
let (h, _) = comps[ci];
let blocks_x = mcus_x * h as usize;
let bi = my * blocks_x + mx;
let vals = &blocks[ci][bi];
if *ah == 0 {
// First AC pass: point-transform magnitudes (sign kept) before coding.
let mut tv = [0i32; 64];
for k in *ss as usize..=*se as usize {
let m = (vals[k].unsigned_abs() >> al) as i32;
tv[k] = if vals[k] < 0 { -m } else { m };
}
encode_ac_band(&mut enc, &mut ac, &tv, *ss as usize, *se as usize);
} else {
encode_ac_refine_band(
&mut enc,
&mut refine_bins,
vals,
*ss as usize,
*se as usize,
*al,
);
}
}
since_restart += 1;
}
}
// Flush the scan's final interval.
out.extend_from_slice(&enc.finish());
}
// EOI.
out.extend_from_slice(&[0xFF, 0xD9]); out
}
/// Tiny xorshift32 generator (shift triple 13 / 17 / 5) used to build
/// reproducible coefficient fixtures. A zero state is a fixed point, so seeds
/// must be non-zero to get a useful stream.
struct Rng(u32);
impl Rng {
    /// Advances the state by one xorshift step and returns the new state.
    fn next(&mut self) -> u32 {
        let mut state = self.0;
        state ^= state << 13;
        state ^= state >> 17;
        state ^= state << 5;
        self.0 = state;
        state
    }
}
/// Builds `n` pseudo-random 64-coefficient blocks from an [`Rng`] seeded with
/// `seed`.
///
/// For non-negative amplitudes, slot 0 (DC) receives a value reduced modulo
/// into `-dc_amp..=dc_amp`. Then `ac_count` times, a slot in `1..=63` is picked
/// and overwritten with a value reduced into `-ac_amp..=ac_amp`. Slots can be
/// picked more than once (the last write wins); untouched slots stay zero. The
/// output is fully determined by the arguments.
fn gen_blocks(
    n: usize,
    seed: u32,
    dc_amp: i32,
    ac_amp: i32,
    ac_count: usize,
) -> Vec<[i32; 64]> {
    let mut rng = Rng(seed);
    (0..n)
        .map(|_| {
            let mut coeffs = [0i32; 64];
            let dc_draw = rng.next() % (2 * dc_amp as u32 + 1);
            coeffs[0] = dc_draw as i32 - dc_amp;
            for _ in 0..ac_count {
                // The slot is drawn before the value; keep this order so the
                // fixtures stay byte-for-byte reproducible.
                let slot = 1 + rng.next() as usize % 63;
                let ac_draw = rng.next() % (2 * ac_amp as u32 + 1);
                coeffs[slot] = ac_draw as i32 - ac_amp;
            }
            coeffs
        })
        .collect()
}
/// Reference reconstruction shared by [`expected_plane_8`] and
/// [`expected_plane_12`].
///
/// `blocks` holds coefficient blocks in raster order, `blocks_x` per row, each
/// in zig-zag order (slot `k` is scattered to `ZIGZAG[k]`). Each block goes
/// through [`idct8x8`]. Every output sample is then shifted by `level_shift`,
/// clamped to `0..=max_sample` (rounding values strictly inside the range),
/// and the top-left `w` x `h` window is returned in raster order.
///
/// # Panics
/// If `blocks_x` is zero, `blocks.len()` is not a multiple of `blocks_x`, or
/// the `w` x `h` window does not fit inside the block grid.
fn expected_plane_samples(
    blocks: &[[i32; 64]],
    blocks_x: usize,
    w: usize,
    h: usize,
    level_shift: f32,
    max_sample: u16,
) -> Vec<u16> {
    assert!(
        blocks_x > 0 && blocks.len() % blocks_x == 0,
        "expected a whole number of block rows ({} blocks, {} per row)",
        blocks.len(),
        blocks_x
    );
    let bw = blocks_x * 8;
    let blocks_y = blocks.len() / blocks_x;
    assert!(
        w <= bw && h <= blocks_y * 8,
        "requested {w}x{h} window exceeds the {}x{} block grid",
        bw,
        blocks_y * 8
    );
    let max_f = f32::from(max_sample);
    let mut full = vec![0u16; bw * blocks_y * 8];
    for (bi, vals) in blocks.iter().enumerate() {
        let mut nat = [0.0f32; 64];
        for (k, &coef) in vals.iter().enumerate() {
            nat[ZIGZAG[k]] = coef as f32;
        }
        idct8x8(&mut nat);
        let (bx, by) = (bi % blocks_x, bi / blocks_x);
        for j in 0..8 {
            let row = &mut full[(by * 8 + j) * bw + bx * 8..][..8];
            for (i, px) in row.iter_mut().enumerate() {
                let v = nat[j * 8 + i] + level_shift;
                // Explicit bounds first: this must match the decoder's
                // clamp-then-round behaviour exactly.
                *px = if v <= 0.0 {
                    0
                } else if v >= max_f {
                    max_sample
                } else {
                    v.round() as u16
                };
            }
        }
    }
    let mut out = Vec::with_capacity(w * h);
    for y in 0..h {
        out.extend_from_slice(&full[y * bw..y * bw + w]);
    }
    out
}

/// Expected 8-bit plane: level shift 128, samples clamped to `0..=255`.
fn expected_plane_8(blocks: &[[i32; 64]], blocks_x: usize, w: usize, h: usize) -> Vec<u8> {
    expected_plane_samples(blocks, blocks_x, w, h, 128.0, 255)
        .into_iter()
        .map(|s| u8::try_from(s).expect("8-bit samples are clamped to 0..=255"))
        .collect()
}

/// Expected 12-bit plane: level shift 2048, samples clamped to `0..=4095`.
fn expected_plane_12(blocks: &[[i32; 64]], blocks_x: usize, w: usize, h: usize) -> Vec<u16> {
    expected_plane_samples(blocks, blocks_x, w, h, 2048.0, 4095)
}
/// Decodes `jpeg` and checks that it yields exactly one plane with stride `w`
/// whose first `w * h` samples equal [`expected_plane_8`] for `blocks`.
///
/// # Panics
/// On decode failure, an unexpected plane count or stride, or the first
/// mismatching pixel (reported as `(x,y)`).
fn assert_gray8_exact(jpeg: &[u8], blocks: &[[i32; 64]], blocks_x: usize, w: usize, h: usize) {
    let frame = decode_jpeg(jpeg, None).expect("decode SOF10");
    assert_eq!(frame.planes.len(), 1);
    let plane = &frame.planes[0];
    assert_eq!(plane.stride, w);
    let want = expected_plane_8(blocks, blocks_x, w, h);
    let got = &plane.data[..w * h];
    for (idx, (&actual, &expected)) in got.iter().zip(want.iter()).enumerate() {
        let (x, y) = (idx % w, idx / w);
        assert_eq!(actual, expected, "pixel mismatch at ({x},{y})");
    }
}
#[test]
fn sof10_gray8_spectral_selection_roundtrip() {
    // 8-bit grey, 3x2 blocks, split into three spectral bands (DC, AC 1..=5,
    // AC 6..=63) with no successive approximation.
    let blocks_x = 3;
    let (w, h) = (24usize, 16usize);
    let coeffs = gen_blocks(blocks_x * 2, 0xC0FFEE11, 200, 60, 10);
    let scans: Vec<ScanDesc> = vec![
        (vec![0], 0, 0, 0, 0),
        (vec![0], 1, 5, 0, 0),
        (vec![0], 6, 63, 0, 0),
    ];
    let planes = std::slice::from_ref(&coeffs);
    let jpeg = encode_sof10_jpeg(w, h, 8, &[(1, 1)], planes, &scans, 0, None);
    let has_sof10 = jpeg.windows(2).any(|pair| pair == [0xFF, 0xCA]);
    assert!(has_sof10, "SOF10 marker missing");
    assert_gray8_exact(&jpeg, &coeffs, blocks_x, w, h);
}
#[test]
fn sof10_gray8_full_progression_roundtrip() {
    // Every band is sent once at Al=1, then refined with Ah=1 / Al=0.
    let blocks_x = 3;
    let (w, h) = (24usize, 16usize);
    let coeffs = gen_blocks(blocks_x * 2, 0xDEC0DE22, 180, 50, 12);
    let bands = [(0, 0), (1, 5), (6, 63)];
    let first_pass = bands.iter().map(|&(ss, se)| (vec![0], ss, se, 0, 1));
    let refine_pass = bands.iter().map(|&(ss, se)| (vec![0], ss, se, 1, 0));
    let scans: Vec<ScanDesc> = first_pass.chain(refine_pass).collect();
    let planes = std::slice::from_ref(&coeffs);
    let jpeg = encode_sof10_jpeg(w, h, 8, &[(1, 1)], planes, &scans, 0, None);
    assert_gray8_exact(&jpeg, &coeffs, blocks_x, w, h);
}
#[test]
fn sof10_gray8_two_level_sa_roundtrip() {
    // Successive approximation in three steps (Al 2 -> 1 -> 0). Each step
    // sends a DC scan followed by a full AC 1..=63 scan.
    let blocks_x = 2;
    let (w, h) = (16usize, 16usize);
    let coeffs = gen_blocks(blocks_x * 2, 0x5EED3333, 120, 40, 14);
    let scans: Vec<ScanDesc> = [(0, 2), (2, 1), (1, 0)]
        .iter()
        .flat_map(|&(ah, al)| [(vec![0], 0, 0, ah, al), (vec![0], 1, 63, ah, al)])
        .collect();
    let planes = std::slice::from_ref(&coeffs);
    let jpeg = encode_sof10_jpeg(w, h, 8, &[(1, 1)], planes, &scans, 0, None);
    assert_gray8_exact(&jpeg, &coeffs, blocks_x, w, h);
}
#[test]
fn sof10_yuv420_interleaved_dc_roundtrip() {
    // Luma sampled 2x2, both chroma components 1x1.
    let (w, h) = (32usize, 16usize);
    let comps = [(2u8, 2u8), (1, 1), (1, 1)];
    let blocks = vec![
        gen_blocks(4 * 2, 0xAAAA0001, 150, 40, 8),
        gen_blocks(2, 0xBBBB0002, 100, 30, 6),
        gen_blocks(2, 0xCCCC0003, 100, 30, 6),
    ];
    // One interleaved DC scan, then a non-interleaved AC scan per component.
    let scans: Vec<ScanDesc> = std::iter::once((vec![0, 1, 2], 0, 0, 0, 0))
        .chain((0..3).map(|ci| (vec![ci], 1, 63, 0, 0)))
        .collect();
    let jpeg = encode_sof10_jpeg(w, h, 8, &comps, &blocks, &scans, 0, None);
    let decoded = decode_jpeg(&jpeg, None).expect("decode SOF10 4:2:0");
    assert_eq!(decoded.planes.len(), 3);
    let (cw, ch) = (w / 2, h / 2);
    let layout = [(w, h, 4usize), (cw, ch, 2), (cw, ch, 2)];
    for (ci, (plane, &(pw, ph, bx))) in decoded.planes.iter().zip(&layout).enumerate() {
        assert_eq!(plane.stride, pw, "plane {ci} stride");
        let want = expected_plane_8(&blocks[ci], bx, pw, ph);
        assert_eq!(plane.data, want, "plane {ci} samples");
    }
}
#[test]
fn sof10_restart_interval_roundtrip() {
    // Restart interval of 3 over a 4x2-block image, so every scan must
    // contain RSTn markers.
    let blocks_x = 4;
    let restart_interval = 3;
    let (w, h) = (32usize, 16usize);
    let coeffs = gen_blocks(blocks_x * 2, 0x12345678, 160, 45, 9);
    let scans: Vec<ScanDesc> = [(0, 1), (1, 0)]
        .iter()
        .flat_map(|&(ah, al)| [(vec![0], 0, 0, ah, al), (vec![0], 1, 63, ah, al)])
        .collect();
    let planes = std::slice::from_ref(&coeffs);
    let jpeg = encode_sof10_jpeg(w, h, 8, &[(1, 1)], planes, &scans, restart_interval, None);
    let has_rst = jpeg
        .windows(2)
        .any(|pair| pair[0] == 0xFF && (0xD0..=0xD7).contains(&pair[1]));
    assert!(has_rst, "no RSTn marker found in the scan");
    assert_gray8_exact(&jpeg, &coeffs, blocks_x, w, h);
}
#[test]
fn sof10_dac_kx_conditioning_roundtrip() {
    // Same DC + AC spectral split, but with a non-default AC conditioning
    // value (Kx = 20) passed to the encoder.
    let blocks_x = 2;
    let dac_kx = Some(20);
    let (w, h) = (16usize, 16usize);
    let coeffs = gen_blocks(blocks_x * 2, 0x0BAD5EED, 140, 50, 16);
    let scans: Vec<ScanDesc> = vec![(vec![0], 0, 0, 0, 0), (vec![0], 1, 63, 0, 0)];
    let planes = std::slice::from_ref(&coeffs);
    let jpeg = encode_sof10_jpeg(w, h, 8, &[(1, 1)], planes, &scans, 0, dac_kx);
    assert_gray8_exact(&jpeg, &coeffs, blocks_x, w, h);
}
#[test]
fn sof10_gray12_full_progression_roundtrip() {
    // 12-bit grey: DC and AC first at Al=1, then refined with Ah=1 / Al=0.
    let blocks_x = 2;
    let (w, h) = (16usize, 16usize);
    let coeffs = gen_blocks(blocks_x * 2, 0x600DCAFE, 4000, 900, 10);
    let scans: Vec<ScanDesc> = [(0, 1), (1, 0)]
        .iter()
        .flat_map(|&(ah, al)| [(vec![0], 0, 0, ah, al), (vec![0], 1, 63, ah, al)])
        .collect();
    let planes = std::slice::from_ref(&coeffs);
    let jpeg = encode_sof10_jpeg(w, h, 12, &[(1, 1)], planes, &scans, 0, None);
    let decoded = decode_jpeg(&jpeg, None).expect("decode SOF10 12-bit");
    assert_eq!(decoded.planes.len(), 1);
    let plane = &decoded.planes[0];
    // The decoded plane stores each sample as two little-endian bytes.
    assert_eq!(plane.stride, w * 2);
    let want = expected_plane_12(&coeffs, blocks_x, w, h);
    let packed = &plane.data[..w * h * 2];
    for (i, (pair, &expected)) in packed.chunks_exact(2).zip(&want).enumerate() {
        let got = u16::from_le_bytes([pair[0], pair[1]]);
        assert_eq!(got, expected, "sample mismatch at {i}");
    }
}
#[test]
fn sof10_cmyk_spectral_selection_roundtrip() {
    // Four 1x1-sampled components; the decoder returns them as a single plane
    // with 4 interleaved bytes per pixel.
    let (w, h) = (16usize, 8usize);
    let comps = [(1u8, 1u8); 4];
    let blocks: Vec<Vec<[i32; 64]>> = (0u32..4)
        .map(|seed_offset| gen_blocks(2, 0x4444_0000 + seed_offset, 120, 35, 7))
        .collect();
    let scans: Vec<ScanDesc> = std::iter::once((vec![0, 1, 2, 3], 0, 0, 0, 0))
        .chain((0..4).map(|ci| (vec![ci], 1, 63, 0, 0)))
        .collect();
    let jpeg = encode_sof10_jpeg(w, h, 8, &comps, &blocks, &scans, 0, None);
    let decoded = decode_jpeg(&jpeg, None).expect("decode SOF10 CMYK");
    assert_eq!(decoded.planes.len(), 1);
    let plane = &decoded.planes[0];
    assert_eq!(plane.stride, w * 4);
    let want: Vec<Vec<u8>> = blocks.iter().map(|b| expected_plane_8(b, 2, w, h)).collect();
    let interleaved = &plane.data[..w * h * 4];
    for (i, px) in interleaved.chunks_exact(4).enumerate() {
        for (ci, &got) in px.iter().enumerate() {
            assert_eq!(got, want[ci][i], "component {ci} mismatch at sample {i}");
        }
    }
}
#[test]
fn sof10_unsupported_precision_rejected() {
    // A 10-bit SOF10 stream must fail to decode with an `Unsupported` error.
    let planes = [gen_blocks(1, 1, 50, 10, 3)];
    let scans: Vec<ScanDesc> = vec![(vec![0], 0, 0, 0, 0)];
    let jpeg = encode_sof10_jpeg(8, 8, 10, &[(1, 1)], &planes, &scans, 0, None);
    let err = decode_jpeg(&jpeg, None).expect_err("expected decode error");
    let is_unsupported = matches!(err, crate::error::MjpegError::Unsupported(_));
    assert!(is_unsupported, "expected Unsupported, got {err:?}");
}
}
#[cfg(test)]
mod dnl_unit_tests {
    //! Unit tests for `resolve_dnl_height` over hand-built minimal streams.
    use super::resolve_dnl_height;
    use crate::jpeg::markers;

    /// Appends a marker segment: `FF <marker>`, a big-endian length equal to
    /// `payload.len() + 2`, then `payload`.
    fn put_seg(out: &mut Vec<u8>, marker: u8, payload: &[u8]) {
        let seg_len = (payload.len() + 2) as u16;
        out.push(0xFF);
        out.push(marker);
        out.extend_from_slice(&seg_len.to_be_bytes());
        out.extend_from_slice(payload);
    }

    /// Builds a minimal SOF0 stream: one 8-bit component, 16 samples wide,
    /// with `y` in the SOF height field. It has one SOS followed by a single
    /// entropy-coded byte, an optional DNL segment carrying `dnl`, then EOI.
    fn build(y: u16, dnl: Option<u16>) -> Vec<u8> {
        // SOF payload: precision 8, height `y`, width 16, one component
        // (id 1, sampling 0x11, quant table 0).
        let mut sof = Vec::with_capacity(9);
        sof.push(8u8);
        sof.extend_from_slice(&y.to_be_bytes());
        sof.extend_from_slice(&16u16.to_be_bytes());
        sof.extend_from_slice(&[1, 1, 0x11, 0]);

        let mut out = Vec::new();
        put_seg(&mut out, markers::SOF0, &sof);
        // SOS payload: one component (id 1, tables 0), Ss=0, Se=63, Ah/Al=0.
        put_seg(&mut out, markers::SOS, &[1, 1, 0, 0, 63, 0]);
        out.push(0x00);
        if let Some(nl) = dnl {
            put_seg(&mut out, markers::DNL, &nl.to_be_bytes());
        }
        out.extend_from_slice(&[0xFF, markers::EOI]);
        out
    }

    #[test]
    fn non_zero_y_needs_no_resolution() {
        let stream = build(16, None);
        let resolved = resolve_dnl_height(&stream).unwrap();
        assert_eq!(resolved, None);
    }

    #[test]
    fn zero_y_with_dnl_resolves() {
        let stream = build(0, Some(123));
        let resolved = resolve_dnl_height(&stream).unwrap();
        assert_eq!(resolved, Some(123));
    }

    #[test]
    fn zero_y_without_dnl_errors() {
        let stream = build(0, None);
        assert!(resolve_dnl_height(&stream).is_err());
    }

    #[test]
    fn zero_y_zero_nl_errors() {
        let stream = build(0, Some(0));
        assert!(resolve_dnl_height(&stream).is_err());
    }
}