j2k-jpeg 0.6.1

JPEG inspect/decode and fallback encode support for j2k
Documentation
// SPDX-License-Identifier: Apache-2.0

use crate::entropy::block::{decode_block_with_activity, CoefficientBlock};
use crate::entropy::sequential::PreparedDecodePlan;
use crate::error::{JpegError, MarkerKind};
use crate::internal::bit_reader::{BitReader, BitReaderSnapshot};

/// Cache of entropy-decoder resume points that is grown on demand.
///
/// `checkpoint_before_mcu` seeds an empty cache with an MCU-0 checkpoint and
/// `extend_non_restart_checkpoints` appends further entries as later MCUs are
/// requested, always continuing from the last stored entry.
#[derive(Debug, Default)]
pub(crate) struct CpuCheckpointCache {
    /// Stored checkpoints; the code in this file only ever appends entries
    /// with increasing `mcu_index`, and treats the last one as the furthest
    /// point decoded so far.
    pub(crate) checkpoints: Vec<DeviceCheckpoint>,
}

#[derive(Debug, Clone, PartialEq, Eq)]
/// Entropy decoder resume point for device-side JPEG decode.
///
/// A checkpoint captures the bit-reader state and decoder bookkeeping as they
/// are immediately before the MCU numbered `mcu_index` is decoded.
pub struct DeviceCheckpoint {
    /// Index of the first MCU that has not been decoded yet at this point;
    /// decoding can resume here.
    pub mcu_index: u32,
    /// Bit-reader position (`BitReaderSnapshot::pos`) within the scan entropy
    /// data handed to the reader.
    pub scan_offset: usize,
    /// Buffered entropy bits at the checkpoint (`BitReaderSnapshot::acc`).
    pub bit_accumulator: u64,
    /// Number of valid bits in `bit_accumulator` (`BitReaderSnapshot::bits`).
    pub bits_buffered: u8,
    /// Previous DC predictor, indexed by each component's output index.
    /// All entries are zero at the start of the scan and after a restart.
    pub prev_dc: [i32; 4],
    /// Index (0..=7) of the next expected restart marker; the marker byte
    /// itself is `0xD0 | expected_rst`.
    pub expected_rst: u8,
}

/// Decodes the whole scan once and records every entropy-decoder resume point.
///
/// The first checkpoint always describes MCU 0. When the plan has a non-zero
/// restart interval, one further checkpoint is recorded right after each
/// restart marker and `cadence_mcus` is ignored. Otherwise a checkpoint is
/// recorded before every MCU whose index is a multiple of `cadence_mcus`
/// (a cadence of 0 is treated as 1).
///
/// # Errors
///
/// * Any error from `validate_scan_bytes` or from decoding an MCU.
/// * `JpegError::UnexpectedEoi` when a restart marker is due but none is found.
/// * `JpegError::RestartMismatch` when the restart marker found is not the
///   expected one in the `0xD0..=0xD7` cycle.
/// * `JpegError::UnexpectedMarker` when the marker following the last MCU is
///   something other than EOI (`0xD9`).
pub(crate) fn build_checkpoint_plan(
    plan: &PreparedDecodePlan,
    scan_bytes: &[u8],
    cadence_mcus: u32,
) -> Result<Vec<DeviceCheckpoint>, JpegError> {
    let total_mcus = total_mcus(plan);
    let cadence_mcus = cadence_mcus.max(1);
    // A restart interval of zero is treated the same as no restart interval.
    let restart_interval = plan
        .restart_interval
        .filter(|&interval| interval > 0)
        .map(u32::from);
    validate_scan_bytes(scan_bytes, restart_interval.is_some())?;

    let reader_bytes = terminated_scan_bytes(scan_bytes);

    // Checkpoints are emitted once per restart interval (or cadence step), not
    // once per MCU, so reserve only that many slots. Reserving `total_mcus`
    // entries would over-allocate by a factor of the step size.
    let checkpoint_step = restart_interval.unwrap_or(cadence_mcus);
    let expected_checkpoints = total_mcus.div_ceil(checkpoint_step).max(1);
    let mut checkpoints = Vec::with_capacity(expected_checkpoints as usize);
    let mut br = BitReader::new(&reader_bytes);
    let mut coeff = CoefficientBlock::default();
    let mut prev_dc = [0i32; 4];
    let mut expected_rst = 0u8;
    let mut mcus_since_restart = 0u32;

    checkpoints.push(snapshot_checkpoint(0, &br, prev_dc, expected_rst));

    for mcu_index in 0..total_mcus {
        if mcu_index > 0 {
            if let Some(restart) = restart_interval {
                if mcus_since_restart == restart {
                    // The result is deliberately ignored; the call is made only
                    // for its effect on the reader before the marker lookup.
                    // NOTE(review): presumably this advances the reader up to
                    // the pending marker — confirm against `BitReader`.
                    let _ = br.ensure_bits(1);
                    let marker = br.take_marker().ok_or(JpegError::UnexpectedEoi {
                        mcu_at: mcu_index,
                        mcu_total: total_mcus,
                    })?;
                    let expected = 0xD0 | expected_rst;
                    if marker != expected {
                        return Err(JpegError::RestartMismatch {
                            offset: br.position(),
                            expected: expected_rst,
                            found: marker,
                        });
                    }
                    // Restart markers cycle through RST0..RST7.
                    expected_rst = (expected_rst + 1) & 0x07;
                    br.reset_at_restart();
                    prev_dc.fill(0);
                    mcus_since_restart = 0;
                    checkpoints.push(snapshot_checkpoint(mcu_index, &br, prev_dc, expected_rst));
                }
            } else if mcu_index.is_multiple_of(cadence_mcus) {
                checkpoints.push(snapshot_checkpoint(mcu_index, &br, prev_dc, expected_rst));
            }
        }

        decode_one_mcu(plan, &mut br, &mut coeff, &mut prev_dc)?;
        mcus_since_restart += 1;
    }

    // After the last MCU only EOI, or no pending marker at all, is acceptable.
    match br.take_marker() {
        Some(0xd9) | None => {}
        Some(found) => {
            return Err(JpegError::UnexpectedMarker {
                offset: br.position().saturating_sub(2),
                expected: MarkerKind::Eoi,
                found,
            })
        }
    }

    Ok(checkpoints)
}

/// Returns the latest cadence-aligned resume point at or before `target_mcu`,
/// extending `cache` by decoding further into the scan when necessary.
///
/// Returns `Ok(None)` when no checkpoint other than the start of the scan
/// applies: the plan uses a non-zero restart interval, `target_mcu` is 0, or
/// the target lies before the first cadence boundary. `target_mcu` is clamped
/// to the plan's MCU count and a `cadence_mcus` of 0 is treated as 1.
///
/// # Errors
///
/// Propagates any error raised while decoding MCUs to extend the cache.
pub(crate) fn checkpoint_before_mcu(
    plan: &PreparedDecodePlan,
    scan_bytes: &[u8],
    cadence_mcus: u32,
    target_mcu: u32,
    cache: &mut CpuCheckpointCache,
) -> Result<Option<DeviceCheckpoint>, JpegError> {
    // Treat a zero restart interval as "no restarts", exactly as
    // `build_checkpoint_plan` does, so both entry points agree on which scans
    // get cadence checkpoints.
    let uses_restarts = plan
        .restart_interval
        .filter(|&interval| interval > 0)
        .is_some();
    if uses_restarts || target_mcu == 0 {
        return Ok(None);
    }

    let total_mcus = total_mcus(plan);
    let target_mcu = target_mcu.min(total_mcus);
    let cadence_mcus = cadence_mcus.max(1);
    // Round the target down to the nearest cadence boundary.
    let target_checkpoint_mcu = (target_mcu / cadence_mcus) * cadence_mcus;
    if target_checkpoint_mcu == 0 {
        return Ok(None);
    }

    // Seed the cache with the start-of-scan state so extension has a base.
    if cache.checkpoints.is_empty() {
        cache.checkpoints.push(snapshot_checkpoint(
            0,
            &BitReader::new(scan_bytes),
            [0; 4],
            0,
        ));
    }

    let last_mcu = cache
        .checkpoints
        .last()
        .map_or(0, |checkpoint| checkpoint.mcu_index);
    if last_mcu < target_checkpoint_mcu {
        extend_non_restart_checkpoints(
            plan,
            scan_bytes,
            cadence_mcus,
            target_checkpoint_mcu,
            cache,
        )?;
    }

    // Search from the newest entry backwards; the MCU-0 seed is never returned.
    Ok(cache
        .checkpoints
        .iter()
        .rev()
        .find(|checkpoint| checkpoint.mcu_index > 0 && checkpoint.mcu_index <= target_mcu)
        .cloned())
}

/// Decodes forward from the newest cached checkpoint up to
/// `target_checkpoint_mcu`, appending a checkpoint to `cache` each time the
/// number of decoded MCUs reaches a multiple of `cadence_mcus`.
///
/// When the cache is empty, decoding starts from the beginning of
/// `scan_bytes`. Checkpoints appended here always carry a restart index of 0.
///
/// # Errors
///
/// Propagates the first error returned while decoding an MCU; checkpoints
/// appended before the failure stay in the cache.
fn extend_non_restart_checkpoints(
    plan: &PreparedDecodePlan,
    scan_bytes: &[u8],
    cadence_mcus: u32,
    target_checkpoint_mcu: u32,
    cache: &mut CpuCheckpointCache,
) -> Result<(), JpegError> {
    let resume_from = match cache.checkpoints.last() {
        Some(latest) => latest.clone(),
        None => snapshot_checkpoint(0, &BitReader::new(scan_bytes), [0; 4], 0),
    };

    let reader_state = BitReaderSnapshot {
        pos: resume_from.scan_offset,
        acc: resume_from.bit_accumulator,
        bits: resume_from.bits_buffered,
    };
    let mut reader = BitReader::from_snapshot(scan_bytes, reader_state);
    let mut predictors = resume_from.prev_dc;
    let mut scratch = CoefficientBlock::default();

    for decoded in resume_from.mcu_index..target_checkpoint_mcu {
        decode_one_mcu(plan, &mut reader, &mut scratch, &mut predictors)?;

        // The reader now sits in front of the MCU that follows `decoded`.
        let next_mcu = decoded + 1;
        if next_mcu.is_multiple_of(cadence_mcus) {
            let checkpoint = snapshot_checkpoint(next_mcu, &reader, predictors, 0);
            cache.checkpoints.push(checkpoint);
        }
    }

    Ok(())
}

/// Copies `scan_bytes` into a new buffer that is guaranteed to end with the
/// two bytes `0xFF 0xD9`.
///
/// Input that already ends with that pair is copied unchanged. Input ending in
/// a lone `0xFF` gets only `0xD9` appended; anything else gets both bytes.
fn terminated_scan_bytes(scan_bytes: &[u8]) -> Vec<u8> {
    const EOI: [u8; 2] = [0xff, 0xd9];

    let suffix: &[u8] = if scan_bytes.ends_with(&EOI) {
        &[]
    } else if scan_bytes.last() == Some(&0xff) {
        // Complete the dangling 0xFF instead of adding a second one.
        &EOI[1..]
    } else {
        &EOI
    };

    let mut terminated = Vec::with_capacity(scan_bytes.len() + 2);
    terminated.extend_from_slice(scan_bytes);
    terminated.extend_from_slice(suffix);
    terminated
}

/// Checks that every `0xFF` in `scan_bytes` is followed by a byte that is
/// acceptable inside entropy-coded data.
///
/// Accepted followers are `0x00` (a stuffed byte) and, when
/// `allow_restart_markers` is set, `0xD0..=0xD7`. Scanning stops successfully
/// at the first `0xFF 0xD9` pair, at a trailing lone `0xFF`, or at the end of
/// the input.
///
/// # Errors
///
/// Returns `JpegError::UnexpectedMarker` carrying the offset of the offending
/// `0xFF` and the byte that followed it.
fn validate_scan_bytes(scan_bytes: &[u8], allow_restart_markers: bool) -> Result<(), JpegError> {
    let mut cursor = 0usize;

    while let Some(distance) = scan_bytes[cursor..].iter().position(|&byte| byte == 0xff) {
        let marker_offset = cursor + distance;
        let Some(&code) = scan_bytes.get(marker_offset + 1) else {
            // A lone 0xFF as the final byte is tolerated.
            return Ok(());
        };

        let is_stuffed_byte = code == 0x00;
        let is_allowed_restart = allow_restart_markers && (0xd0..=0xd7).contains(&code);
        if is_stuffed_byte || is_allowed_restart {
            cursor = marker_offset + 2;
            continue;
        }

        if code == 0xd9 {
            return Ok(());
        }

        return Err(JpegError::UnexpectedMarker {
            offset: marker_offset,
            expected: MarkerKind::Eoi,
            found: code,
        });
    }

    Ok(())
}

/// Builds a `DeviceCheckpoint` for `mcu_index` from the bit reader's current
/// snapshot together with the supplied DC predictors and restart index.
fn snapshot_checkpoint(
    mcu_index: u32,
    br: &BitReader<'_>,
    prev_dc: [i32; 4],
    expected_rst: u8,
) -> DeviceCheckpoint {
    let state = br.snapshot();
    let (scan_offset, bit_accumulator, bits_buffered) = (state.pos, state.acc, state.bits);

    DeviceCheckpoint {
        mcu_index,
        scan_offset,
        bit_accumulator,
        bits_buffered,
        prev_dc,
        expected_rst,
    }
}

/// Entropy-decodes every block of one MCU, updating the DC predictors.
///
/// For each component of the plan, `h * v` blocks are decoded with that
/// component's tables. The decoded coefficients land in `coeff`, which is
/// overwritten block after block; the per-block return value is discarded.
///
/// # Errors
///
/// Propagates the first error returned by `decode_block_with_activity`.
fn decode_one_mcu(
    plan: &PreparedDecodePlan,
    br: &mut BitReader<'_>,
    coeff: &mut CoefficientBlock,
    prev_dc: &mut [i32; 4],
) -> Result<(), JpegError> {
    for component in &plan.components {
        let blocks_per_mcu = u32::from(component.h) * u32::from(component.v);
        let predictor_slot = component.output_index;

        for _block in 0..blocks_per_mcu {
            let _ = decode_block_with_activity(
                br,
                &component.dc_table,
                &component.ac_table,
                &mut prev_dc[predictor_slot],
                &component.quant,
                coeff,
            )?;
        }
    }

    Ok(())
}

/// Number of MCUs covering the image: the MCU columns times the MCU rows,
/// saturating at `u32::MAX`.
///
/// An MCU spans `8 * max_h` by `8 * max_v` pixels; partially covered MCUs at
/// the right and bottom edges count as whole ones.
fn total_mcus(plan: &PreparedDecodePlan) -> u32 {
    let columns = plan
        .dimensions
        .0
        .div_ceil(u32::from(plan.sampling.max_h) * 8);
    let rows = plan
        .dimensions
        .1
        .div_ceil(u32::from(plan.sampling.max_v) * 8);

    columns.saturating_mul(rows)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::decoder::Decoder;
    use crate::internal::bit_reader::{BitReader, BitReaderSnapshot};

    /// With a cadence of 1 the eager plan holds one checkpoint per MCU, so
    /// resuming from any checkpoint and decoding a single MCU must reproduce
    /// the state stored in the following checkpoint.
    #[test]
    fn non_restart_checkpoints_resume_cleanly() {
        let bytes = grayscale_jpeg(24, 24);
        let decoder = Decoder::new(&bytes).expect("decoder");
        let plan = &decoder.plan;
        let scan_bytes = &decoder.bytes[plan.scan_offset..];
        let checkpoints = build_checkpoint_plan(plan, scan_bytes, 1).expect("checkpoints");
        // Resume against the same terminated buffer the plan builder read from,
        // so stored offsets refer to identical bytes.
        let reader_bytes = terminated_scan_bytes(scan_bytes);

        for pair in checkpoints.windows(2) {
            let mut prev_dc = pair[0].prev_dc;
            let mut coeff = CoefficientBlock::default();
            let mut br = BitReader::from_snapshot(
                &reader_bytes,
                BitReaderSnapshot {
                    pos: pair[0].scan_offset,
                    acc: pair[0].bit_accumulator,
                    bits: pair[0].bits_buffered,
                },
            );

            decode_one_mcu(plan, &mut br, &mut coeff, &mut prev_dc).expect("decode one mcu");
            let resumed =
                snapshot_checkpoint(pair[1].mcu_index, &br, prev_dc, pair[0].expected_rst);

            assert_eq!(resumed.scan_offset, pair[1].scan_offset);
            assert_eq!(resumed.bit_accumulator, pair[1].bit_accumulator);
            assert_eq!(resumed.bits_buffered, pair[1].bits_buffered);
            assert_eq!(resumed.prev_dc, pair[1].prev_dc);
            assert_eq!(resumed.expected_rst, pair[1].expected_rst);
        }
    }

    /// Replacing the trailing EOI marker with `0xFF 0xE0` must make the plan
    /// builder fail with an `UnexpectedMarker` error naming `0xE0`.
    #[test]
    fn checkpoint_plan_rejects_non_eoi_terminal_marker() {
        let mut bytes = grayscale_jpeg(24, 24);
        // The fixture ends with 0xFF 0xD9; overwrite the final 0xD9.
        let tail = bytes.len() - 1;
        bytes[tail] = 0xe0;

        let decoder = Decoder::new(&bytes).expect("decoder");
        let plan = &decoder.plan;
        let scan_bytes = &decoder.bytes[plan.scan_offset..];
        let err = build_checkpoint_plan(plan, scan_bytes, 1).expect_err("terminal APPn must fail");

        assert!(matches!(
            err,
            JpegError::UnexpectedMarker {
                expected: MarkerKind::Eoi,
                found: 0xe0,
                ..
            }
        ));
    }

    /// A lazy lookup for MCU 17 with a cadence of 4 must return the checkpoint
    /// at MCU 16 and leave that checkpoint in the cache.
    #[test]
    fn lazy_non_restart_checkpoint_extends_to_nearest_cadence_before_target() {
        // 48x48 pixels with 8x8 MCUs gives 36 MCUs, so MCU 17 is in range.
        let bytes = grayscale_jpeg(48, 48);
        let decoder = Decoder::new(&bytes).expect("decoder");
        let plan = &decoder.plan;
        let scan_bytes = &decoder.bytes[plan.scan_offset..];
        let mut cache = CpuCheckpointCache::default();

        let checkpoint = checkpoint_before_mcu(plan, scan_bytes, 4, 17, &mut cache)
            .expect("checkpoint lookup")
            .expect("target beyond one cadence should produce a checkpoint");

        assert_eq!(checkpoint.mcu_index, 16);
        assert!(cache
            .checkpoints
            .iter()
            .any(|checkpoint| checkpoint.mcu_index == 16));
    }

    /// Builds a minimal single-component baseline JPEG of the given size whose
    /// entropy data consists solely of zero bytes.
    fn grayscale_jpeg(width: u16, height: u16) -> Vec<u8> {
        let mut bytes = Vec::new();
        // SOI marker.
        bytes.extend_from_slice(&[0xff, 0xd8]);
        // DQT segment: length 67, table 0 with 8-bit entries, all 64 set to 16.
        bytes.extend_from_slice(&[0xff, 0xdb, 0x00, 67, 0x00]);
        bytes.extend(std::iter::repeat_n(16u8, 64));
        // SOF0 segment: length 11, 8-bit precision, height, width, and one
        // component (id 1, 1x1 sampling, quantization table 0).
        bytes.extend_from_slice(&[
            0xff,
            0xc0,
            0x00,
            11,
            8,
            (height >> 8) as u8,
            height as u8,
            (width >> 8) as u8,
            width as u8,
            1,
            1,
            0x11,
            0,
        ]);
        // DHT segment, DC table 0: a single 1-bit code mapping to symbol 0.
        bytes.extend_from_slice(&[
            0xff, 0xc4, 0x00, 20, 0x00, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ]);
        // DHT segment, AC table 0: a single 1-bit code mapping to symbol 0
        // (end of block).
        bytes.extend_from_slice(&[
            0xff, 0xc4, 0x00, 20, 0x10, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ]);
        // SOS segment: length 8, one component (id 1, DC/AC tables 0),
        // spectral selection 0..=63, no successive approximation.
        bytes.extend_from_slice(&[0xff, 0xda, 0x00, 0x08, 1, 1, 0x00, 0, 63, 0]);

        // Entropy data: one zero byte per MCU. With the 1-bit codes above each
        // MCU needs only two zero bits, so this is more than enough data.
        let mcu_cols = u32::from(width).div_ceil(8);
        let mcu_rows = u32::from(height).div_ceil(8);
        let mcu_count = (mcu_cols * mcu_rows) as usize;
        bytes.extend(std::iter::repeat_n(0x00, mcu_count));

        // EOI marker.
        bytes.extend_from_slice(&[0xff, 0xd9]);
        bytes
    }
}