zenjxl-decoder 0.3.8

High performance Rust implementation of a JPEG XL decoder
Documentation
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

use std::ops::Range;

use jxl_simd::{F32SimdVec, SimdDescriptor, U8SimdVec, U16SimdVec, simd_function};

use crate::{
    api::{Endianness, JxlDataFormat, JxlOutputBuffer},
    render::low_memory_pipeline::row_buffers::RowBuffer,
};

/// Defines `$fn_name`, a SIMD kernel that interleaves `$cnt` planar slices of `$ty`
/// (one per `$arg` identifier) into the packed slice `out` using `$store_fn` of the
/// `$vec_trait` vector type.
///
/// The kernel only works in whole vectors. It first uses `D::$vec_trait` at the full width of
/// `D`. Then, for each of the `D::Descriptor256` and `D::Descriptor128` vector types that is
/// narrower than the full width, it consumes as much of the remainder as fits in whole
/// vectors. Trailing samples that do not fill a complete vector of the narrowest width are
/// left untouched.
///
/// The generated function returns how many samples per channel were written; they always form
/// a prefix starting at index 0, and the caller must handle the rest. The range considered is
/// the shortest input length, further capped at `out.len() / $cnt`. Inputs of unequal length,
/// or a short output, therefore shrink the processed prefix instead of panicking.
macro_rules! define_run_interleaved {
    ($fn_name:ident, $ty:ty, $vec_trait:ident, $store_fn:ident, $cnt:expr, $($arg:ident),+) => {
        #[inline(always)]
        fn $fn_name<D: SimdDescriptor>(
            d: D,
            $($arg: &[$ty]),+,
            out: &mut [$ty],
        ) -> usize {
            let len = D::$vec_trait::LEN;
            let mut n = 0;
            // Number of samples per channel that every input and `out` can all provide. The
            // `.next().unwrap()` calls below rely on each input being at least this long.
            let mut limit = out.len() / $cnt;
            $(limit = limit.min($arg.len());)+

            // Full-width pass. The block scope ends the shadowing of the input slices by
            // their chunk iterators, so the later passes see the original slices again.
            {
                let out_chunks = out[..limit * $cnt].chunks_exact_mut(len * $cnt);
                $(let mut $arg = $arg[..limit].chunks_exact(len);)+
                for out_chunk in out_chunks {
                    $(let $arg = D::$vec_trait::load(d, $arg.next().unwrap());)+
                    D::$vec_trait::$store_fn($($arg),+, out_chunk);
                    n += len;
                }
            }

            // `Descriptor256` pass over the leftover samples, only if its vectors are
            // narrower than the full-width ones.
            let d256 = d.maybe_downgrade_256bit();
            let len256 = <D::Descriptor256 as SimdDescriptor>::$vec_trait::LEN;
            if len256 < len {
                let out_chunks = out[n * $cnt..limit * $cnt].chunks_exact_mut(len256 * $cnt);
                $(let mut $arg = $arg[n..limit].chunks_exact(len256);)+
                for out_chunk in out_chunks {
                    $(let $arg = <D::Descriptor256 as SimdDescriptor>::$vec_trait::load(d256, $arg.next().unwrap());)+
                    <D::Descriptor256 as SimdDescriptor>::$vec_trait::$store_fn($($arg),+, out_chunk);
                    n += len256;
                }
            }

            // `Descriptor128` pass over whatever is still left, again only if narrower than
            // the full width.
            let d128 = d.maybe_downgrade_128bit();
            let len128 = <D::Descriptor128 as SimdDescriptor>::$vec_trait::LEN;
            if len128 < len {
                let out_chunks = out[n * $cnt..limit * $cnt].chunks_exact_mut(len128 * $cnt);
                $(let mut $arg = $arg[n..limit].chunks_exact(len128);)+
                for out_chunk in out_chunks {
                    $(let $arg = <D::Descriptor128 as SimdDescriptor>::$vec_trait::load(d128, $arg.next().unwrap());)+
                    <D::Descriptor128 as SimdDescriptor>::$vec_trait::$store_fn($($arg),+, out_chunk);
                    n += len128;
                }
            }

            n
        }
    };
}

// `f32` interleaving kernels for 2, 3 and 4 planes. The trailing identifiers only name the
// per-channel slice parameters of each generated function; callers pass them positionally.
define_run_interleaved!(
    run_interleaved_2_f32,
    f32,
    F32Vec,
    store_interleaved_2,
    2,
    ch0,
    ch1
);
define_run_interleaved!(
    run_interleaved_3_f32,
    f32,
    F32Vec,
    store_interleaved_3,
    3,
    ch0,
    ch1,
    ch2
);
define_run_interleaved!(
    run_interleaved_4_f32,
    f32,
    F32Vec,
    store_interleaved_4,
    4,
    ch0,
    ch1,
    ch2,
    ch3
);

// Interleaves 2-4 planar `f32` slices into `output` with the widest available SIMD
// descriptor. Returns the number of samples per channel written; any other plane count
// writes nothing and returns 0.
simd_function!(
    store_interleaved_f32,
    d: D,
    fn store_interleaved_impl_f32(
        inputs: &[&[f32]],
        output: &mut [f32]
    ) -> usize {
        match *inputs {
            [p0, p1] => run_interleaved_2_f32(d, p0, p1, output),
            [p0, p1, p2] => run_interleaved_3_f32(d, p0, p1, p2, output),
            [p0, p1, p2, p3] => run_interleaved_4_f32(d, p0, p1, p2, p3, output),
            // Unsupported plane count: nothing stored.
            _ => 0,
        }
    }
);

// `u8` interleaving kernels for 2, 3 and 4 planes. The trailing identifiers only name the
// per-channel slice parameters of each generated function; callers pass them positionally.
define_run_interleaved!(
    run_interleaved_2_u8,
    u8,
    U8Vec,
    store_interleaved_2,
    2,
    ch0,
    ch1
);
define_run_interleaved!(
    run_interleaved_3_u8,
    u8,
    U8Vec,
    store_interleaved_3,
    3,
    ch0,
    ch1,
    ch2
);
define_run_interleaved!(
    run_interleaved_4_u8,
    u8,
    U8Vec,
    store_interleaved_4,
    4,
    ch0,
    ch1,
    ch2,
    ch3
);

// Interleaves 2-4 planar `u8` slices into `output` with the widest available SIMD
// descriptor. Returns the number of samples per channel written; any other plane count
// writes nothing and returns 0.
simd_function!(
    store_interleaved_u8,
    d: D,
    fn store_interleaved_impl_u8(
        inputs: &[&[u8]],
        output: &mut [u8]
    ) -> usize {
        match *inputs {
            [p0, p1] => run_interleaved_2_u8(d, p0, p1, output),
            [p0, p1, p2] => run_interleaved_3_u8(d, p0, p1, p2, output),
            [p0, p1, p2, p3] => run_interleaved_4_u8(d, p0, p1, p2, p3, output),
            // Unsupported plane count: nothing stored.
            _ => 0,
        }
    }
);

// `u16` interleaving kernels for 2, 3 and 4 planes. The trailing identifiers only name the
// per-channel slice parameters of each generated function; callers pass them positionally.
define_run_interleaved!(
    run_interleaved_2_u16,
    u16,
    U16Vec,
    store_interleaved_2,
    2,
    ch0,
    ch1
);
define_run_interleaved!(
    run_interleaved_3_u16,
    u16,
    U16Vec,
    store_interleaved_3,
    3,
    ch0,
    ch1,
    ch2
);
define_run_interleaved!(
    run_interleaved_4_u16,
    u16,
    U16Vec,
    store_interleaved_4,
    4,
    ch0,
    ch1,
    ch2,
    ch3
);

// Interleaves 2-4 planar `u16` slices into `output` with the widest available SIMD
// descriptor. Returns the number of samples per channel written; any other plane count
// writes nothing and returns 0.
simd_function!(
    store_interleaved_u16,
    d: D,
    fn store_interleaved_impl_u16(
        inputs: &[&[u16]],
        output: &mut [u16]
    ) -> usize {
        match *inputs {
            [p0, p1] => run_interleaved_2_u16(d, p0, p1, output),
            [p0, p1, p2] => run_interleaved_3_u16(d, p0, p1, p2, output),
            [p0, p1, p2, p3] => run_interleaved_4_u16(d, p0, p1, p2, p3, output),
            // Unsupported plane count: nothing stored.
            _ => 0,
        }
    }
);

/// Writes the samples `xrange` of row `input_y` from the planar channel buffers in
/// `input_buf` into row `output_y` of `output_buf`, interleaving the channels according to
/// `data_format`.
///
/// Only fast paths are attempted:
/// - a single channel in native byte order: a plain byte copy;
/// - 2 to 4 channels in native byte order with 1, 2 or 4 bytes per sample: SIMD
///   interleaving via `store_interleaved_u8` / `store_interleaved_u16` /
///   `store_interleaved_f32`.
///
/// Returns the number of samples per channel that were stored, counted from `xrange.start`.
/// A return of 0 means nothing was written here. That happens for:
/// - a non-native byte order;
/// - an unsupported channel count or sample size;
/// - an output row that cannot be reinterpreted as `u16`/`f32` slices.
///
/// NOTE(review): callers presumably finish any remaining samples with a slower path —
/// confirm.
///
/// # Panics
/// Panics if the output row is shorter than the requested bytes times the channel count, or
/// if an input row does not cover the requested byte range.
pub(super) fn store(
    input_buf: &[&RowBuffer],
    input_y: usize,
    xrange: Range<usize>,
    output_buf: &mut JxlOutputBuffer,
    output_y: usize,
    data_format: JxlDataFormat,
) -> usize {
    let sample_bytes = data_format.bytes_per_sample();
    let x0 = RowBuffer::x0_byte_offset();
    let byte_start = xrange.start * sample_bytes + x0;
    let byte_end = xrange.end * sample_bytes + x0;

    // U8 has no byte order; every other format is only fast-pathed in native endianness.
    let native = match data_format {
        JxlDataFormat::U8 { .. } => true,
        JxlDataFormat::F16 { endianness, .. }
        | JxlDataFormat::U16 { endianness, .. }
        | JxlDataFormat::F32 { endianness, .. } => endianness == Endianness::native(),
    };

    let channels = input_buf.len();
    // The x0 offset cancels out in the difference, so this is exactly the packed row size.
    let row = output_buf.row_mut(output_y);
    let out = &mut row[0..(byte_end - byte_start) * channels];

    if !native {
        return 0;
    }

    match (channels, sample_bytes) {
        (1, _) => {
            // One plane needs no interleaving: copy the bytes verbatim.
            let src = &input_buf[0].get_row::<u8>(input_y)[byte_start..byte_end];
            assert_eq!(src.len(), out.len());
            out.copy_from_slice(src);
            src.len() / sample_bytes
        }
        (2..=4, 1) => {
            let mut planes = [&[] as &[u8]; 4];
            for (plane, buf) in planes.iter_mut().zip(input_buf) {
                *plane = &buf.get_row::<u8>(input_y)[byte_start..byte_end];
            }
            store_interleaved_u8(&planes[..channels], out)
        }
        (2..=4, 2) => match bytemuck::try_cast_slice_mut::<u8, u16>(out) {
            Ok(out16) => {
                let (lo, hi) = (byte_start / 2, byte_end / 2);
                let mut planes = [&[] as &[u16]; 4];
                for (plane, buf) in planes.iter_mut().zip(input_buf) {
                    *plane = &buf.get_row::<u16>(input_y)[lo..hi];
                }
                store_interleaved_u16(&planes[..channels], out16)
            }
            // Output row cannot be viewed as `u16`: leave it to the caller.
            Err(_) => 0,
        },
        (2..=4, 4) => match bytemuck::try_cast_slice_mut::<u8, f32>(out) {
            Ok(out32) => {
                let (lo, hi) = (byte_start / 4, byte_end / 4);
                let mut planes = [&[] as &[f32]; 4];
                for (plane, buf) in planes.iter_mut().zip(input_buf) {
                    *plane = &buf.get_row::<f32>(input_y)[lo..hi];
                }
                store_interleaved_f32(&planes[..channels], out32)
            }
            // Output row cannot be viewed as `f32`: leave it to the caller.
            Err(_) => 0,
        },
        _ => 0,
    }
}