compeg 0.4.0

A JPEG decoder implemented as a WebGPU compute shader
Documentation
//! Scan Data preprocessing.
//!
//! The scan data in a JPEG image is uploaded to GPU memory and then decoded by a compute shader.
//! But before that can happen, this module performs some preprocessing on the data. This
//! preprocessing involves:
//!
//! - Finding `RST` markers and saving their locations, so that the compute shader invocations know
//!   where to start.
//! - Finding `0xFF 0x00` byte stuffing sequences and replacing them with `0xFF`.
//! - Aligning the start of each restart interval with a `u32` boundary, so that the shader does
//!   not have to perform byte-wise alignment (WGSL does not have `u8`).

use crate::error::Error;

/// Reusable buffer holding JPEG scan data that has been preprocessed for upload to GPU memory.
#[derive(Debug, Default)]
pub struct ScanBuffer {
    /// Data is stored as `u32`s because that's WebGPU's smallest native integer type. Each
    /// restart interval begins on a word boundary, but the contained bytes are stored packed. This
    /// allows each shader invocation to load a `u32` from memory and start decoding it immediately,
    /// rather than having to shift bytes around first.
    words: Vec<u32>,
    /// Start offsets of the restart intervals in `words`, measured in words (not bytes).
    start_positions: Vec<u32>,
}

impl ScanBuffer {
    /// Creates an empty buffer.
    ///
    /// No memory is allocated until [`ScanBuffer::process`] is called.
    pub fn new() -> Self {
        Self {
            words: Vec::new(),
            start_positions: Vec::new(),
        }
    }

    /// Preprocesses `scan_data`, replacing any previous contents of this buffer.
    ///
    /// Byte stuffing (`0xFF 0x00`) is undone and RST markers are removed, with every restart
    /// interval starting on a `u32` boundary. The word offset of each restart interval is
    /// recorded and can be retrieved with [`ScanBuffer::start_positions`].
    ///
    /// # Errors
    ///
    /// Returns an error if the number of restart intervals found in `scan_data` differs from
    /// `expected_restart_intervals`. The buffer contents are unspecified afterwards, but the
    /// buffer can still be reused for another call.
    pub fn process(
        &mut self,
        scan_data: &[u8],
        expected_restart_intervals: u32,
    ) -> crate::Result<()> {
        // The RST markers are removed from `scan_data`, but each restart interval is padded to
        // start on a 32-bit boundary. That means at worst a 1-byte restart interval preceded by a
        // 2-byte RST marker will occupy one word and waste 3 bytes in there.
        // This is 1 more byte than the input had, so we have to allocate an extra 1/3rd of the
        // input data length in the output buffer.
        let out_bytes = scan_data.len() + scan_data.len() / 3;
        // Clear before resizing so that *every* word is zeroed, not just newly added ones: the
        // alignment padding is never written by the preprocessor and must not carry stale bytes
        // from a previous call.
        self.words.clear();
        self.words.resize((out_bytes + 3) / 4, 0);

        // Every RST marker occupies 2 bytes, so `scan_data` cannot contain more than
        // `len / 2 + 1` restart intervals. Capping the slot count there bounds the allocation
        // when the expected count (taken from the image header) is implausibly large, and keeps
        // `next_power_of_two` from overflowing. Such a count can never match, so the mismatch is
        // still reported below with the exact counted value.
        let max_restart_intervals = scan_data.len() / 2 + 1;
        let start_pos_buffer_length = (expected_restart_intervals as usize)
            .min(max_restart_intervals)
            .next_power_of_two();
        // Clear for the same reason as `words`: no slot may hold a value from a previous call.
        self.start_positions.clear();
        self.start_positions.resize(start_pos_buffer_length, 0);

        let out: &mut [u8] = bytemuck::cast_slice_mut(&mut self.words);

        let res = preprocess_scalar(out, &mut self.start_positions, scan_data);

        self.words.truncate((res.bytes_out + 3) / 4);
        self.start_positions.truncate(res.ri);

        if res.ri != expected_restart_intervals as usize {
            return Err(Error::from(format!(
                "restart interval count mismatch: counted {}, expected {}",
                res.ri, expected_restart_intervals
            )));
        }

        Ok(())
    }

    /// Returns the preprocessed scan data, ready for upload to GPU memory.
    ///
    /// The length is always a multiple of 4; padding bytes are zero.
    pub fn processed_scan_data(&self) -> &[u8] {
        bytemuck::cast_slice(&self.words)
    }

    /// Returns the computed start positions (one `u32` word offset per restart interval), ready
    /// for upload to GPU memory.
    pub fn start_positions(&self) -> &[u8] {
        bytemuck::cast_slice(&self.start_positions)
    }
}

/// Summary returned by `preprocess_scalar`.
struct PreprocessResult {
    /// Number of bytes written to the output buffer, including alignment padding between
    /// restart intervals but excluding any padding after the last data byte.
    bytes_out: usize,
    /// Number of restart intervals encountered (always at least 1).
    ri: usize,
}

/// Copies `scan_data` into `out`, removing byte stuffing and RST markers, and records the word
/// offset at which each restart interval starts.
///
/// - `0xFF 0x00` is written to `out` as a single `0xFF`.
/// - `0xFF` followed by any other byte is treated as an RST marker. The marker type is not
///   checked, for speed. The write position is rounded up to the next multiple of 4, and the
///   resulting word index is stored as the start of the next restart interval.
/// - A trailing lone `0xFF` at the very end of `scan_data` is dropped.
///
/// `start_positions` is used as a ring buffer. Its length must be a power of two, and slots are
/// indexed with `ri & (len - 1)`. If there are more restart intervals than slots, earlier entries
/// are overwritten; callers detect that case through the returned count. Bytes of `out` that are
/// skipped by alignment are left untouched.
///
/// # Panics
///
/// Panics if `start_positions.len()` is not a power of two, or if `out` is too small to hold the
/// processed data.
#[inline(never)]
fn preprocess_scalar(
    out: &mut [u8],
    start_positions: &mut [u32],
    scan_data: &[u8],
) -> PreprocessResult {
    assert!(start_positions.len().is_power_of_two());
    let start_pos_mask = start_positions.len() - 1;

    // The first restart interval always starts at word 0. Write it explicitly rather than
    // relying on the caller's buffer being zeroed: a reused buffer may hold a stale value here
    // (e.g. after a previous call wrapped around the ring buffer).
    start_positions[0] = 0;

    let mut ri = 1;
    let mut write_ptr = 0;
    let mut bytes = scan_data.iter().copied();
    loop {
        match bytes.next() {
            Some(0xff) => match bytes.next() {
                Some(0x00) => {
                    // Byte stuffing sequence, push only `0xFF` to the output.
                    out[write_ptr] = 0xff;
                    write_ptr += 1;
                }
                Some(_) => {
                    // RST marker. We don't check the exact marker type to improve perf. Only
                    // RST is valid here.
                    // NOTE(review): JPEG permits `0xFF` fill bytes before markers; a sequence
                    // like `0xFF 0xFF 0xD0` would be misparsed here. Confirm the caller never
                    // passes fill bytes.

                    // Align the next restart interval on a 4-byte boundary.
                    write_ptr = (write_ptr + 0b11) & !0b11;

                    start_positions[ri & start_pos_mask] = (write_ptr / 4) as u32;
                    ri += 1;
                }
                // Lone `0xFF` at the end of the data: nothing meaningful to emit.
                None => break,
            },
            Some(byte) => {
                out[write_ptr] = byte;
                write_ptr += 1;
            }
            None => break,
        }
    }

    PreprocessResult {
        ri,
        bytes_out: write_ptr,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Processes `scan_data` expecting `start_positions.len()` restart intervals and checks that
    /// the processed bytes (including zero padding) and the start positions match exactly.
    fn check(scan_data: &[u8], output: &[u8], start_positions: &[u32]) {
        let mut buf = ScanBuffer::new();
        buf.process(scan_data, start_positions.len() as u32)
            .unwrap();

        let bytes: &[u8] = bytemuck::cast_slice(&buf.words);
        assert_eq!(output, bytes);

        assert_eq!(buf.start_positions, start_positions);
    }

    /// Processes `scan_data` expecting `start_positions.len()` restart intervals and returns the
    /// resulting error, panicking if processing succeeds.
    fn check_err(scan_data: &[u8], start_positions: &[u32]) -> Error {
        let mut buf = ScanBuffer::new();
        buf.process(scan_data, start_positions.len() as u32)
            .unwrap_err()
    }

    #[test]
    fn process_scan_data() {
        check(&[0x12, 0x34, 0x56, 0x78], &[0x12, 0x34, 0x56, 0x78], &[0]);

        check(&[0xFF, 0xD0, 0xFF, 0xD0], &[], &[0, 0, 0]);

        let scan = &[0xFF, 0x00, 0x44, 0x55, 0xFF, 0xD0, 0x34];
        check(scan, &[0xFF, 0x44, 0x55, 0x00, 0x34, 0, 0, 0], &[0, 1]);
    }

    #[test]
    fn test_empty_input() {
        // Empty scan data still consists of one (empty) restart interval starting at word 0.
        check(&[], &[], &[0]);
    }

    #[test]
    fn test_trailing_ff_is_dropped() {
        check(&[0x12, 0xFF], &[0x12, 0, 0, 0], &[0]);
    }

    #[test]
    fn test_stuffing_around_rst() {
        check(
            &[0xFF, 0x00, 0xFF, 0xD0, 0xFF, 0x00],
            &[0xFF, 0, 0, 0, 0xFF, 0, 0, 0],
            &[0, 1],
        );
    }

    #[test]
    fn test_expanding_output() {
        // 3 bytes of input data get expanded to 4 bytes of output data. This tests that we allocate
        // enough space in the output buffer.
        check(
            &[0x11, 0xFF, 0xD0, 0x11, 0xFF, 0xD0, 0x11],
            &[0x11, 0, 0, 0, 0x11, 0, 0, 0, 0x11, 0, 0, 0],
            &[0, 1, 2],
        );
    }

    #[test]
    fn test_too_many_rst_markers() {
        // Data corruption can make the actual number of RST markers exceed the expected number.
        let err = check_err(&[0x11, 0xFF, 0xD0, 0x11, 0xFF, 0xD0, 0x11], &[0]);
        assert_eq!(
            err.to_string(),
            "restart interval count mismatch: counted 3, expected 1"
        );
    }

    #[test]
    fn test_too_few_rst_markers() {
        let err = check_err(&[0x11, 0xFF, 0xD0, 0x11], &[0, 0, 0]);
        assert_eq!(
            err.to_string(),
            "restart interval count mismatch: counted 2, expected 3"
        );
    }

    #[test]
    fn test_implausible_expected_count() {
        // A header can claim far more restart intervals than the scan data could possibly hold.
        let err = check_err(&[0x11, 0xFF, 0xD0, 0x11], &[0; 1000]);
        assert_eq!(
            err.to_string(),
            "restart interval count mismatch: counted 2, expected 1000"
        );
    }
}