jxl-render 0.12.4

JPEG XL image renderer, part of jxl-oxide
Documentation
use jxl_grid::{MutableSubgrid, SimdVector};

use super::super::dct_common::{self, DctDirection};
use std::arch::aarch64::*;

const LANE_SIZE: usize = 4;
type Lane = float32x4_t;

#[inline(always)]
pub(crate) unsafe fn transpose_lane(lanes: &[Lane]) -> float32x4x4_t {
    assert_eq!(lanes.len(), 4);
    let ptr = lanes.as_ptr() as *mut f32;
    vld4q_f32(ptr as *const _)
}

#[inline(always)]
pub(crate) unsafe fn dct_2d_aarch64_neon(io: &mut MutableSubgrid<'_>, direction: DctDirection) {
    if !io.width().is_multiple_of(LANE_SIZE) || !io.height().is_multiple_of(LANE_SIZE) {
        return super::generic::dct_2d(io, direction);
    }

    let Some(mut io) = io.as_vectored() else {
        tracing::trace!("Input buffer is not aligned");
        return super::generic::dct_2d(io, direction);
    };

    if io.width() == 2 && io.height() == 8 {
        unsafe {
            return dct8x8(&mut io, direction);
        }
    }

    dct_2d_lane(&mut io, direction);
}

fn dct_2d_lane(io: &mut MutableSubgrid<'_, Lane>, direction: DctDirection) {
    let scratch_size = io.height().max(io.width() * LANE_SIZE) * 2;
    unsafe {
        let mut scratch_lanes = vec![Lane::zero(); scratch_size];
        column_dct_lane(io, &mut scratch_lanes, direction);
        row_dct_lane(io, &mut scratch_lanes, direction);
    }
}

#[inline]
unsafe fn dct4_vec_forward(v: Lane) -> Lane {
    const SEC0: f32 = 0.5411961;
    const SEC1: f32 = 1.306563;

    let v01 = vget_low_f32(v);
    let v23 = vget_high_f32(v);
    let addsub = vcombine_f32(
        vadd_f32(v01, vrev64_f32(v23)),
        vfma_n_f32(vrev64_f32(v01), v23, -1f32),
    );

    let addsub3012 = vextq_f32(addsub, addsub, 3);
    let addsub03 = vrev64_f32(vget_low_f32(addsub3012));
    let addsub12 = vget_high_f32(addsub3012);

    let a = vcombine_f32(addsub03, addsub12);
    let mul_a = Lane::set([
        0.25,
        (std::f32::consts::FRAC_1_SQRT_2 / 2.0 + 0.25) * SEC0,
        -0.25,
        -0.25 * SEC1,
    ]);
    let b = vcombine_f32(addsub12, addsub03);
    let mul_b = Lane::set([
        0.25,
        (std::f32::consts::FRAC_1_SQRT_2 / 2.0 - 0.25) * SEC1,
        0.25,
        0.25 * SEC0,
    ]);
    a.muladd(mul_a, b.mul(mul_b))
}

#[inline]
pub(crate) unsafe fn dct4_vec_inverse(v: Lane) -> Lane {
    const SEC0: f32 = 0.5411961;
    const SEC1: f32 = 1.306563;

    let v_flip = vextq_f32(v, v, 2);
    let mul_a = Lane::set([1.0, (std::f32::consts::SQRT_2 + 1.0) * SEC0, -1.0, -SEC1]);
    let mul_b = Lane::set([1.0, SEC0, 1.0, (std::f32::consts::SQRT_2 - 1.0) * SEC1]);
    let tmp = v.muladd(mul_a, v_flip.mul(mul_b));

    let float32x4x2_t(tmp_a, tmp_b) = vuzpq_f32(tmp, vextq_f32(tmp, tmp, 2));
    let mul = vcombine_f32(vdup_n_f32(1.0), vdup_n_f32(-1.0));
    tmp_b.muladd(mul, tmp_a)
}

#[inline]
unsafe fn dct8_vec_forward(vl: Lane, vr: Lane) -> (Lane, Lane) {
    #[allow(clippy::excessive_precision)]
    let sec_vec = Lane::set([
        0.2548977895520796,
        0.30067244346752264,
        0.4499881115682078,
        1.2814577238707527,
    ]);
    let vr_rev = vrev64q_f32(vextq_f32(vr, vr, 2));
    let input0 = vmulq_n_f32(vl.add(vr_rev), 0.5);
    let input1 = vl.sub(vr_rev).mul(sec_vec);
    let output0 = dct4_vec_forward(input0);
    let output1 = dct4_vec_forward(input1);
    let output1_shifted = vextq_f32(output1, Lane::zero(), 1);
    let output1_mul = vsetq_lane_f32(std::f32::consts::SQRT_2, Lane::splat_f32(1.0), 0);
    let output1 = output1.muladd(output1_mul, output1_shifted);
    (vzip1q_f32(output0, output1), vzip2q_f32(output0, output1))
}

#[inline]
pub(crate) unsafe fn dct8_vec_inverse(vl: Lane, vr: Lane) -> (Lane, Lane) {
    #[allow(clippy::excessive_precision)]
    let sec_vec = Lane::set([
        0.5097955791041592,
        0.6013448869350453,
        0.8999762231364156,
        2.5629154477415055,
    ]);
    let float32x4x2_t(input0, input1) = vuzpq_f32(vl, vr);
    let input1_shifted = vextq_f32(Lane::zero(), input1, 3);
    let input1_mul = vsetq_lane_f32(std::f32::consts::SQRT_2, Lane::splat_f32(1.0), 0);
    let input1 = input1.muladd(input1_mul, input1_shifted);
    let output0 = dct4_vec_inverse(input0);
    let output1 = dct4_vec_inverse(input1);
    let output1 = output1.mul(sec_vec);
    let sub = output0.sub(output1);
    (output0.add(output1), vrev64q_f32(vextq_f32(sub, sub, 2)))
}

unsafe fn dct8x8(io: &mut MutableSubgrid<'_, Lane>, direction: DctDirection) {
    let (mut col0, mut col1) = io.split_horizontal(1);

    if direction == DctDirection::Forward {
        dct8_forward(&mut col0);
        dct8_forward(&mut col1);
        for y in 0..8 {
            let row = io.get_row_mut(y);
            let (vl, vr) = dct8_vec_forward(row[0], row[1]);
            row[0] = vl;
            row[1] = vr;
        }
    } else {
        dct8_inverse(&mut col0);
        dct8_inverse(&mut col1);
        for y in 0..8 {
            let row = io.get_row_mut(y);
            let (vl, vr) = dct8_vec_inverse(row[0], row[1]);
            row[0] = vl;
            row[1] = vr;
        }
    }
}

unsafe fn column_dct_lane(
    io: &mut MutableSubgrid<'_, Lane>,
    scratch: &mut [Lane],
    direction: DctDirection,
) {
    let width = io.width();
    let height = io.height();
    let (io_lanes, scratch_lanes) = scratch[..height * 2].split_at_mut(height);
    for x in 0..width {
        for (y, input) in io_lanes.iter_mut().enumerate() {
            *input = io.get(x, y);
        }
        dct(io_lanes, scratch_lanes, direction);
        for (y, output) in io_lanes.chunks_exact(LANE_SIZE).enumerate() {
            let float32x4x4_t(o0, o1, o2, o3) = transpose_lane(output);
            *io.get_mut(x, y * LANE_SIZE) = o0;
            *io.get_mut(x, y * LANE_SIZE + 1) = o1;
            *io.get_mut(x, y * LANE_SIZE + 2) = o2;
            *io.get_mut(x, y * LANE_SIZE + 3) = o3;
        }
    }
}

unsafe fn row_dct_lane(
    io: &mut MutableSubgrid<'_, Lane>,
    scratch: &mut [Lane],
    direction: DctDirection,
) {
    let width = io.width() * LANE_SIZE;
    let height = io.height();
    let (io_lanes, scratch_lanes) = scratch[..width * 2].split_at_mut(width);
    for y in (0..height).step_by(LANE_SIZE) {
        for (x, input) in io_lanes.chunks_exact_mut(LANE_SIZE).enumerate() {
            for (dy, input) in input.iter_mut().enumerate() {
                *input = io.get(x, y + dy);
            }
        }
        dct(io_lanes, scratch_lanes, direction);
        for (x, output) in io_lanes.chunks_exact(LANE_SIZE).enumerate() {
            let float32x4x4_t(o0, o1, o2, o3) = transpose_lane(output);
            *io.get_mut(x, y) = o0;
            *io.get_mut(x, y + 1) = o1;
            *io.get_mut(x, y + 2) = o2;
            *io.get_mut(x, y + 3) = o3;
        }
    }
}

#[inline]
unsafe fn dct4_forward(input: [Lane; 4]) -> [Lane; 4] {
    let sec0 = 0.5411961 / 4.0;
    let sec1 = 1.306563 / 4.0;

    let sum03 = input[0].add(input[3]);
    let sum12 = input[1].add(input[2]);
    let tmp0 = vmulq_n_f32(input[0].sub(input[3]), sec0);
    let tmp1 = vmulq_n_f32(input[1].sub(input[2]), sec1);
    let out0 = tmp0.add(tmp1);
    let out1 = tmp0.sub(tmp1);

    [
        vmulq_n_f32(sum03.add(sum12), 0.25),
        vfmaq_n_f32(out1, out0, std::f32::consts::SQRT_2),
        vmulq_n_f32(sum03.sub(sum12), 0.25),
        out1,
    ]
}

#[inline]
pub(crate) unsafe fn dct4_inverse(input: [Lane; 4]) -> [Lane; 4] {
    let sec0 = 0.5411961;
    let sec1 = 1.306563;

    let tmp0 = vmulq_n_f32(input[1], std::f32::consts::SQRT_2);
    let tmp1 = input[1].add(input[3]);
    let out0 = vmulq_n_f32(tmp0.add(tmp1), sec0);
    let out1 = vmulq_n_f32(tmp0.sub(tmp1), sec1);
    let sum02 = input[0].add(input[2]);
    let sub02 = input[0].sub(input[2]);

    [
        sum02.add(out0),
        sub02.add(out1),
        sub02.sub(out1),
        sum02.sub(out0),
    ]
}

#[inline]
unsafe fn dct8_forward(io: &mut MutableSubgrid<'_, Lane>) {
    assert!(io.height() == 8);
    let sec = dct_common::sec_half_small(8);

    let input0 = [
        vmulq_n_f32(io.get(0, 0).add(io.get(0, 7)), 0.5),
        vmulq_n_f32(io.get(0, 1).add(io.get(0, 6)), 0.5),
        vmulq_n_f32(io.get(0, 2).add(io.get(0, 5)), 0.5),
        vmulq_n_f32(io.get(0, 3).add(io.get(0, 4)), 0.5),
    ];
    let input1 = [
        vmulq_n_f32(io.get(0, 0).sub(io.get(0, 7)), sec[0] / 2.0),
        vmulq_n_f32(io.get(0, 1).sub(io.get(0, 6)), sec[1] / 2.0),
        vmulq_n_f32(io.get(0, 2).sub(io.get(0, 5)), sec[2] / 2.0),
        vmulq_n_f32(io.get(0, 3).sub(io.get(0, 4)), sec[3] / 2.0),
    ];
    let output0 = dct4_forward(input0);
    for (idx, v) in output0.into_iter().enumerate() {
        *io.get_mut(0, idx * 2) = v;
    }
    let mut output1 = dct4_forward(input1);
    output1[0] = vmulq_n_f32(output1[0], std::f32::consts::SQRT_2);
    for idx in 0..3 {
        *io.get_mut(0, idx * 2 + 1) = output1[idx].add(output1[idx + 1]);
    }
    *io.get_mut(0, 7) = output1[3];
}

#[inline]
unsafe fn dct8_inverse(io: &mut MutableSubgrid<'_, Lane>) {
    assert!(io.height() == 8);
    let sec = dct_common::sec_half_small(8);

    let input0 = [io.get(0, 0), io.get(0, 2), io.get(0, 4), io.get(0, 6)];
    let input1 = [
        vmulq_n_f32(io.get(0, 1), std::f32::consts::SQRT_2),
        io.get(0, 3).add(io.get(0, 1)),
        io.get(0, 5).add(io.get(0, 3)),
        io.get(0, 7).add(io.get(0, 5)),
    ];
    let output0 = dct4_inverse(input0);
    let output1 = dct4_inverse(input1);
    for (idx, &sec) in sec.iter().enumerate() {
        let r = vmulq_n_f32(output1[idx], sec);
        *io.get_mut(0, idx) = output0[idx].add(r);
        *io.get_mut(0, 7 - idx) = output0[idx].sub(r);
    }
}

unsafe fn dct(io: &mut [Lane], scratch: &mut [Lane], direction: DctDirection) {
    let n = io.len();
    assert!(scratch.len() == n);

    if n == 0 {
        return;
    }
    if n == 1 {
        return;
    }

    if n == 2 {
        let tmp0 = io[0].add(io[1]);
        let tmp1 = io[0].sub(io[1]);
        if direction == DctDirection::Forward {
            io[0] = vmulq_n_f32(tmp0, 0.5);
            io[1] = vmulq_n_f32(tmp1, 0.5);
        } else {
            io[0] = tmp0;
            io[1] = tmp1;
        }
        return;
    }

    if n == 4 {
        if direction == DctDirection::Forward {
            io.copy_from_slice(&dct4_forward([io[0], io[1], io[2], io[3]]));
        } else {
            io.copy_from_slice(&dct4_inverse([io[0], io[1], io[2], io[3]]));
        }
        return;
    }

    if n == 8 {
        if direction == DctDirection::Forward {
            dct8_forward(&mut MutableSubgrid::from_buf(io, 1, 8, 1));
        } else {
            dct8_inverse(&mut MutableSubgrid::from_buf(io, 1, 8, 1));
        }
        return;
    }

    assert!(n.is_power_of_two());

    if direction == DctDirection::Forward {
        let (input0, input1) = scratch.split_at_mut(n / 2);
        for (idx, &sec) in dct_common::sec_half(n).iter().enumerate() {
            input0[idx] = vmulq_n_f32(io[idx].add(io[n - idx - 1]), 0.5);
            input1[idx] = vmulq_n_f32(io[idx].sub(io[n - idx - 1]), sec / 2.0);
        }
        let (output0, output1) = io.split_at_mut(n / 2);
        dct(input0, output0, DctDirection::Forward);
        dct(input1, output1, DctDirection::Forward);
        for (idx, v) in input0.iter().enumerate() {
            io[idx * 2] = *v;
        }
        input1[0] = vmulq_n_f32(input1[0], std::f32::consts::SQRT_2);
        for idx in 0..(n / 2 - 1) {
            io[idx * 2 + 1] = input1[idx].add(input1[idx + 1]);
        }
        io[n - 1] = input1[n / 2 - 1];
    } else {
        let (input0, input1) = scratch.split_at_mut(n / 2);
        for idx in 1..(n / 2) {
            let idx = n / 2 - idx;
            input0[idx] = io[idx * 2];
            input1[idx] = io[idx * 2 + 1].add(io[idx * 2 - 1]);
        }
        input0[0] = io[0];
        input1[0] = vmulq_n_f32(io[1], std::f32::consts::SQRT_2);
        let (output0, output1) = io.split_at_mut(n / 2);
        dct(input0, output0, DctDirection::Inverse);
        dct(input1, output1, DctDirection::Inverse);
        for (idx, &sec) in dct_common::sec_half(n).iter().enumerate() {
            let r = vmulq_n_f32(input1[idx], sec);
            output0[idx] = input0[idx].add(r);
            output1[n / 2 - idx - 1] = input0[idx].sub(r);
        }
    }
}