j2k-metal 0.6.1

Apple Metal GPU adapter for Rust JPEG 2000 and HTJ2K decode/encode paths
Documentation
// SPDX-License-Identifier: Apache-2.0

#[cfg(target_os = "macos")]
use crate::compute;
use j2k_native::{
    decode_ht_code_block_scalar, HtCodeBlockDecodeJob, HtCodeBlockDecoder, J2kInverseMctJob, Result,
};
#[cfg(target_os = "macos")]
use metal::Buffer;

#[derive(Default)]
pub(crate) struct MetalMctDecoder {
    #[cfg(target_os = "macos")]
    kernel_dispatches: usize,
    #[cfg(target_os = "macos")]
    captured_planes: Vec<Buffer>,
}

impl MetalMctDecoder {
    #[cfg(all(test, target_os = "macos"))]
    pub(crate) fn kernel_dispatches(&self) -> usize {
        self.kernel_dispatches
    }

    #[cfg(target_os = "macos")]
    pub(crate) fn take_captured_planes(&mut self) -> Vec<Buffer> {
        core::mem::take(&mut self.captured_planes)
    }
}

impl HtCodeBlockDecoder for MetalMctDecoder {
    fn decode_inverse_mct(&mut self, job: J2kInverseMctJob<'_>) -> Result<bool> {
        #[cfg(target_os = "macos")]
        if supports_metal_inverse_mct(&job) {
            self.captured_planes = compute::decode_inverse_mct(job)
                .map_err(|_| j2k_native::DecodingError::CodeBlockDecodeFailure)?;
            self.kernel_dispatches = self.kernel_dispatches.saturating_add(1);
            return Ok(true);
        }
        #[cfg(not(target_os = "macos"))]
        let _ = job;

        Ok(false)
    }

    fn decode_code_block(
        &mut self,
        job: HtCodeBlockDecodeJob<'_>,
        output: &mut [f32],
    ) -> j2k_native::Result<()> {
        decode_ht_code_block_scalar(job, output)
    }
}

#[cfg(target_os = "macos")]
fn supports_metal_inverse_mct(job: &J2kInverseMctJob<'_>) -> bool {
    let len = job.plane0.len();
    len > 0 && job.plane1.len() == len && job.plane2.len() == len
}

#[cfg(test)]
mod tests {
    use super::MetalMctDecoder;
    use j2k_native::{
        encode, DecodeSettings, DecoderContext, EncodeOptions, HtCodeBlockDecodeJob,
        HtCodeBlockDecoder, Image,
    };

    fn fixture_j2k_rgb8() -> Vec<u8> {
        let pixels = [10u8, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120];
        let options = EncodeOptions {
            reversible: true,
            num_decomposition_levels: 1,
            ..EncodeOptions::default()
        };
        encode(&pixels, 2, 2, 3, 8, false, &options).expect("encode classic rgb8")
    }

    fn fixture_j2k_rgb8_irreversible() -> Vec<u8> {
        let pixels = [10u8, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120];
        let options = EncodeOptions {
            reversible: false,
            num_decomposition_levels: 1,
            ..EncodeOptions::default()
        };
        encode(&pixels, 2, 2, 3, 8, false, &options).expect("encode irreversible rgb8")
    }

    #[test]
    fn metal_mct_decoder_matches_native_decode() {
        let bytes = fixture_j2k_rgb8();
        let image = Image::new(&bytes, &DecodeSettings::default()).expect("image");
        let mut expected_context = DecoderContext::default();
        let expected = image
            .decode_components_with_context(&mut expected_context)
            .expect("native decode");

        let mut hooked_context = DecoderContext::default();
        let mut decoder = MetalMctDecoder::default();
        let actual = image
            .decode_components_with_ht_decoder(&mut hooked_context, &mut decoder)
            .expect("hooked decode");

        assert_eq!(actual.dimensions(), expected.dimensions());
        assert_eq!(actual.planes().len(), expected.planes().len());
        for (actual_plane, expected_plane) in actual.planes().iter().zip(expected.planes().iter()) {
            assert_eq!(
                actual_plane.samples(),
                expected_plane.samples(),
                "Metal MCT output must match native decode"
            );
        }
        #[cfg(target_os = "macos")]
        assert!(
            decoder.kernel_dispatches() > 0,
            "RGB fixture must exercise the Metal MCT kernel"
        );
    }

    struct CpuOnlyCodeBlockDecoder;

    impl HtCodeBlockDecoder for CpuOnlyCodeBlockDecoder {
        fn decode_code_block(
            &mut self,
            job: HtCodeBlockDecodeJob<'_>,
            output: &mut [f32],
        ) -> j2k_native::Result<()> {
            j2k_native::decode_ht_code_block_scalar(job, output)
        }
    }

    #[test]
    fn default_decoder_without_mct_kernel_still_decodes() {
        let bytes = fixture_j2k_rgb8();
        let image = Image::new(&bytes, &DecodeSettings::default()).expect("image");
        let mut context = DecoderContext::default();
        let mut decoder = CpuOnlyCodeBlockDecoder;
        let image_components = image
            .decode_components_with_ht_decoder(&mut context, &mut decoder)
            .expect("decode without mct override");
        assert_eq!(image_components.dimensions(), (2, 2));
    }

    #[test]
    fn metal_mct_decoder_matches_native_decode_for_irreversible_rgb() {
        let bytes = fixture_j2k_rgb8_irreversible();
        let image = Image::new(&bytes, &DecodeSettings::default()).expect("image");
        let mut expected_context = DecoderContext::default();
        let expected = image
            .decode_components_with_context(&mut expected_context)
            .expect("native decode");

        let mut hooked_context = DecoderContext::default();
        let mut decoder = MetalMctDecoder::default();
        let actual = image
            .decode_components_with_ht_decoder(&mut hooked_context, &mut decoder)
            .expect("hooked decode");

        assert_eq!(actual.dimensions(), expected.dimensions());
        for (actual_plane, expected_plane) in actual.planes().iter().zip(expected.planes().iter()) {
            assert_eq!(
                actual_plane.samples(),
                expected_plane.samples(),
                "Metal MCT output must match native decode for irreversible RGB images"
            );
        }
        #[cfg(target_os = "macos")]
        assert!(
            decoder.kernel_dispatches() > 0,
            "irreversible RGB fixture must exercise the Metal MCT kernel"
        );
    }

    #[test]
    fn metal_mct_decoder_captures_final_rgb_planes_matching_host_output() {
        let bytes = fixture_j2k_rgb8();
        let image = Image::new(&bytes, &DecodeSettings::default()).expect("image");
        let mut context = DecoderContext::default();
        let mut decoder = MetalMctDecoder::default();
        let components = image
            .decode_components_with_ht_decoder(&mut context, &mut decoder)
            .expect("hooked decode");
        #[cfg(not(target_os = "macos"))]
        let _ = components;

        #[cfg(target_os = "macos")]
        {
            let captured = decoder.take_captured_planes();
            assert_eq!(captured.len(), components.planes().len());
            for (plane, buffer) in components.planes().iter().zip(captured.iter()) {
                // SAFETY: SIMD lane captures operate on fixed-width stack buffers.
                let captured = unsafe {
                    core::slice::from_raw_parts(
                        buffer.contents().cast::<f32>(),
                        plane.samples().len(),
                    )
                };
                assert_eq!(
                    captured,
                    plane.samples(),
                    "captured Metal MCT planes must match final decoded RGB planes"
                );
            }
        }
    }
}