jxl-render 0.12.4

JPEG XL image renderer, part of jxl-oxide
Documentation
use jxl_grid::{MutableSubgrid, SimdVector};

use super::super::dct_common::{self, DctDirection};
use std::arch::wasm32::*;

/// SIMD vector used by every kernel in this module: one 128-bit WebAssembly register,
/// operated on here as four `f32` components.
type Lane = v128;

/// Number of `f32` components packed into one [`Lane`] (128 bits / 32 bits = 4).
const LANE_SIZE: usize = std::mem::size_of::<Lane>() / std::mem::size_of::<f32>();

/// Transposes a 4x4 block of 32-bit values stored as four lanes.
///
/// Output lane `i` holds component `i` of each input lane, in input order. In other words,
/// `out[i][j] == lanes[j][i]`.
///
/// # Panics
///
/// Panics if `lanes.len() != 4`.
#[inline(always)]
pub(crate) fn transpose_lane(lanes: &[Lane]) -> [Lane; 4] {
    assert_eq!(lanes.len(), 4);
    let (r0, r1, r2, r3) = (lanes[0], lanes[1], lanes[2], lanes[3]);

    // Interleave pairs of rows component by component:
    // lo01 = [r0.0, r1.0, r0.1, r1.1], hi01 = [r0.2, r1.2, r0.3, r1.3], and likewise for r2/r3.
    let lo01 = i32x4_shuffle::<0, 4, 1, 5>(r0, r1);
    let lo23 = i32x4_shuffle::<0, 4, 1, 5>(r2, r3);
    let hi01 = i32x4_shuffle::<2, 6, 3, 7>(r0, r1);
    let hi23 = i32x4_shuffle::<2, 6, 3, 7>(r2, r3);

    // Join the matching halves so that each result lane is one column of the input block.
    [
        i32x4_shuffle::<0, 1, 4, 5>(lo01, lo23),
        i32x4_shuffle::<2, 3, 6, 7>(lo01, lo23),
        i32x4_shuffle::<0, 1, 4, 5>(hi01, hi23),
        i32x4_shuffle::<2, 3, 6, 7>(hi01, hi23),
    ]
}

/// Applies the 2-D DCT (or inverse DCT) to `io` in place, using WebAssembly SIMD when the
/// grid allows it.
///
/// The function falls back to `super::generic::dct_2d` in two cases:
/// - either dimension is not a multiple of `LANE_SIZE`;
/// - `io` cannot be viewed as a grid of lanes, i.e. `as_vectored` returns `None`. This case is
///   traced as an unaligned buffer.
///
/// An 8x8 block (2 lanes x 8 rows once vectored) takes the dedicated `dct8x8` kernel. Every
/// other lane-compatible size goes through `dct_2d_lane`.
///
/// # Safety
///
/// NOTE(review): the contract is not stated in this file. The kernels reached from here are
/// `unsafe fn`s built on `std::arch::wasm32` SIMD, so callers presumably must guarantee that the
/// `simd128` feature is available. Confirm against the call site.
#[inline(always)]
pub(crate) unsafe fn dct_2d_wasm32_simd128(io: &mut MutableSubgrid<'_>, direction: DctDirection) {
    let fits_lanes = io.width() % LANE_SIZE == 0 && io.height() % LANE_SIZE == 0;
    if !fits_lanes {
        return super::generic::dct_2d(io, direction);
    }

    let Some(mut lanes) = io.as_vectored() else {
        tracing::trace!("Input buffer is not aligned");
        return super::generic::dct_2d(io, direction);
    };

    // Vectored width is counted in lanes, so 2 x 8 here is an 8x8 block of floats.
    if lanes.width() == 2 && lanes.height() == 8 {
        unsafe { dct8x8(&mut lanes, direction) }
    } else {
        dct_2d_lane(&mut lanes, direction);
    }
}

/// Separable 2-D transform on a lane-vectored grid.
///
/// Runs a column pass (`column_dct_lane`) and then a row pass (`row_dct_lane`). Both passes
/// share a single zero-initialised scratch buffer.
fn dct_2d_lane(io: &mut MutableSubgrid<'_, Lane>, direction: DctDirection) {
    // The column pass slices `2 * height` lanes of scratch, and the row pass slices
    // `2 * width_in_floats` lanes, so size the buffer for the larger of the two.
    let longest_side = std::cmp::max(io.height(), io.width() * LANE_SIZE);

    // SAFETY: NOTE(review): relies on the SIMD-availability contract of the calling
    // `dct_2d_wasm32_simd128`; this private helper is only invoked from there in this file.
    let zero = unsafe { Lane::zero() };
    let mut scratch = vec![zero; 2 * longest_side];

    // SAFETY: same contract as above. The scratch buffer is large enough for both passes.
    unsafe {
        column_dct_lane(io, &mut scratch, direction);
        row_dct_lane(io, &mut scratch, direction);
    }
}

/// Forward 4-point DCT of the four `f32` components of a single lane.
///
/// This is the same transform that `dct4_forward` applies across four lanes, but here it runs
/// across the components of one vector. Component 0 of the result is the mean of the inputs.
/// Component `k > 0` is `(√2 / 4) · Σ x[n] cos(π k (2n + 1) / 8)`.
#[inline]
unsafe fn dct4_vec_forward(v: Lane) -> Lane {
    const SEC0: f32 = 0.5411961;
    const SEC1: f32 = 1.306563;
    const HALF_FRAC_1_SQRT_2: f32 = std::f32::consts::FRAC_1_SQRT_2 / 2.0;

    // [x0, x1, x2, x3] -> [x3 + x0, x2 + x1, x1 - x2, x0 - x3]
    let reversed = i32x4_shuffle::<3, 2, 1, 0>(v, v);
    let negated = f32x4_neg(v);
    let signed = i32x4_shuffle::<0, 1, 6, 7>(v, negated);
    let sums_diffs = reversed.add(signed);

    // Each output is a two-term dot product. Arrange the operands so that a single
    // multiply-add produces all four outputs.
    let lhs = i32x4_shuffle::<0, 3, 1, 2>(sums_diffs, sums_diffs);
    let rhs = i32x4_shuffle::<1, 2, 0, 3>(sums_diffs, sums_diffs);
    let lhs_coeff = Lane::set([
        0.25,
        (HALF_FRAC_1_SQRT_2 + 0.25) * SEC0,
        -0.25,
        -0.25 * SEC1,
    ]);
    let rhs_coeff = Lane::set([
        0.25,
        (HALF_FRAC_1_SQRT_2 - 0.25) * SEC1,
        0.25,
        0.25 * SEC0,
    ]);
    lhs.muladd(lhs_coeff, rhs.mul(rhs_coeff))
}

/// Inverse 4-point DCT of the four `f32` components of a single lane.
///
/// This is the same transform that `dct4_inverse` applies across four lanes, but here it runs
/// across the components of one vector.
#[inline]
pub(crate) unsafe fn dct4_vec_inverse(v: Lane) -> Lane {
    const SEC0: f32 = 0.5411961;
    const SEC1: f32 = 1.306563;

    // stage = [c0 + c2, SEC0·((√2 + 1)·c1 + c3), c0 - c2, SEC1·((√2 - 1)·c1 - c3)]
    let swapped = i32x4_shuffle::<2, 3, 0, 1>(v, v);
    let self_coeff = Lane::set([1.0, (std::f32::consts::SQRT_2 + 1.0) * SEC0, -1.0, -SEC1]);
    let swapped_coeff = Lane::set([1.0, SEC0, 1.0, (std::f32::consts::SQRT_2 - 1.0) * SEC1]);
    let stage = v.muladd(self_coeff, swapped.mul(swapped_coeff));

    // Butterfly: [e0 + o0, e1 + o1, e1 - o1, e0 - o0]
    let stage_neg = f32x4_neg(stage);
    let evens = i32x4_shuffle::<0, 2, 2, 0>(stage, stage);
    let odds = i32x4_shuffle::<1, 3, 7, 5>(stage, stage_neg);
    evens.add(odds)
}

/// Forward 8-point DCT of the eight floats held in `vl` (elements 0..4) and `vr` (elements
/// 4..8).
///
/// Coefficients are returned split the same way: 0..4 in the first lane and 4..8 in the
/// second. The decomposition is the one `dct8_forward` uses, with the data spread across lane
/// components instead of across lanes.
#[inline]
unsafe fn dct8_vec_forward(vl: Lane, vr: Lane) -> (Lane, Lane) {
    // 1 / (4 cos((2i + 1) π / 16)) for i = 0..4. The values are written with f64 precision,
    // hence the lint allowance.
    #[allow(clippy::excessive_precision)]
    const ODD_SCALE: [f32; 4] = [
        0.2548977895520796,
        0.30067244346752264,
        0.4499881115682078,
        1.2814577238707527,
    ];

    // Fold around the midpoint, pairing x[i] with x[7 - i].
    let tail_reversed = i32x4_shuffle::<3, 2, 1, 0>(vr, vr);
    let folded_sum = vl.add(tail_reversed).mul(Lane::splat_f32(0.5));
    let folded_diff = vl.sub(tail_reversed).mul(Lane::set(ODD_SCALE));

    let even = dct4_vec_forward(folded_sum);
    let odd = dct4_vec_forward(folded_diff);

    // As in `dct8_forward`: sum adjacent odd coefficients, scale the first one by √2, and
    // leave the last one unpaired.
    let odd_next = i32x4_shuffle::<1, 2, 3, 4>(odd, Lane::zero());
    let odd = odd.muladd(
        Lane::set([std::f32::consts::SQRT_2, 1.0, 1.0, 1.0]),
        odd_next,
    );

    // Interleave: even coefficients go to even slots and odd coefficients to odd slots.
    let first = i32x4_shuffle::<0, 4, 1, 5>(even, odd);
    let second = i32x4_shuffle::<2, 6, 3, 7>(even, odd);
    (first, second)
}

/// Inverse 8-point DCT of the eight coefficients held in `vl` (0..4) and `vr` (4..8).
///
/// Returns the samples split the same way. The decomposition is the one `dct8_inverse` uses,
/// with the data spread across lane components.
#[inline]
pub(crate) unsafe fn dct8_vec_inverse(vl: Lane, vr: Lane) -> (Lane, Lane) {
    // 1 / (2 cos((2i + 1) π / 16)) for i = 0..4. The values are written with f64 precision,
    // hence the lint allowance.
    #[allow(clippy::excessive_precision)]
    const ODD_SCALE: [f32; 4] = [
        0.5097955791041592,
        0.6013448869350453,
        0.8999762231364156,
        2.5629154477415055,
    ];

    // De-interleave the even and odd coefficients.
    let evens = i32x4_shuffle::<0, 2, 4, 6>(vl, vr);
    let odds = i32x4_shuffle::<1, 3, 5, 7>(vl, vr);

    // [√2·c1, c3 + c1, c5 + c3, c7 + c5]
    let odds_prev = i32x4_shuffle::<3, 4, 5, 6>(Lane::zero(), odds);
    let folded = odds.muladd(
        Lane::set([std::f32::consts::SQRT_2, 1.0, 1.0, 1.0]),
        odds_prev,
    );

    let even_part = dct4_vec_inverse(evens);
    let odd_part = dct4_vec_inverse(folded).mul(Lane::set(ODD_SCALE));

    // Unfold: the front half is even + odd, and the back half is the reversed even - odd.
    let diff = even_part.sub(odd_part);
    let front = even_part.add(odd_part);
    let back = i32x4_shuffle::<3, 2, 1, 0>(diff, diff);
    (front, back)
}

/// 2-D DCT (or inverse DCT) of an 8x8 block of floats stored as 8 rows x 2 lanes, in place.
///
/// The function first runs a column pass over each one-lane-wide half. Each call to
/// `dct8_forward`/`dct8_inverse` transforms four float columns at once. It then runs a row pass
/// over each 8-float row with `dct8_vec_forward`/`dct8_vec_inverse`. The result is left in
/// natural (untransposed) order.
unsafe fn dct8x8(io: &mut MutableSubgrid<'_, Lane>, direction: DctDirection) {
    let forward = direction == DctDirection::Forward;

    let (mut left, mut right) = io.split_horizontal(1);
    if forward {
        dct8_forward(&mut left);
        dct8_forward(&mut right);
    } else {
        dct8_inverse(&mut left);
        dct8_inverse(&mut right);
    }

    for y in 0..8 {
        let row = io.get_row_mut(y);
        let (lo, hi) = if forward {
            dct8_vec_forward(row[0], row[1])
        } else {
            dct8_vec_inverse(row[0], row[1])
        };
        row[0] = lo;
        row[1] = hi;
    }
}

/// Column pass of the lane-based 2-D DCT.
///
/// For each lane column, the function gathers the whole column into `scratch` and transforms
/// it with `dct`. It then writes the column back with every 4x4 block of floats transposed in
/// place (via `transpose_lane`). That transposed layout is what `row_dct_lane` consumes.
///
/// `scratch` must hold at least `2 * io.height()` lanes. `io.height()` is expected to be a
/// multiple of `LANE_SIZE`, because `chunks_exact` would drop any remainder rows on write-back.
unsafe fn column_dct_lane(
    io: &mut MutableSubgrid<'_, Lane>,
    scratch: &mut [Lane],
    direction: DctDirection,
) {
    let rows = io.height();
    let (column, work) = scratch[..rows * 2].split_at_mut(rows);

    for x in 0..io.width() {
        for (y, slot) in column.iter_mut().enumerate() {
            *slot = io.get(x, y);
        }

        dct(column, work, direction);

        for (block, lanes) in column.chunks_exact(LANE_SIZE).enumerate() {
            let top = block * LANE_SIZE;
            for (dy, lane) in transpose_lane(lanes).into_iter().enumerate() {
                *io.get_mut(x, top + dy) = lane;
            }
        }
    }
}

/// Row pass of the lane-based 2-D DCT. It is designed to run after `column_dct_lane`.
///
/// The function processes strips of `LANE_SIZE` rows:
/// 1. It gathers the strip so that scratch lane `4x + dy` is the lane at `(x, top + dy)`. With
///    4x4 blocks transposed by the column pass, this lays each float row out along the lanes.
/// 2. It transforms the strip with `dct`.
/// 3. It writes each group of four lanes back transposed, which restores natural orientation.
///
/// `scratch` must hold at least `2 * io.width() * LANE_SIZE` lanes.
unsafe fn row_dct_lane(
    io: &mut MutableSubgrid<'_, Lane>,
    scratch: &mut [Lane],
    direction: DctDirection,
) {
    let row_len = io.width() * LANE_SIZE;
    let (strip, work) = scratch[..row_len * 2].split_at_mut(row_len);

    for top in (0..io.height()).step_by(LANE_SIZE) {
        for (c, slot) in strip.iter_mut().enumerate() {
            *slot = io.get(c / LANE_SIZE, top + c % LANE_SIZE);
        }

        dct(strip, work, direction);

        for (x, group) in strip.chunks_exact(LANE_SIZE).enumerate() {
            for (dy, lane) in transpose_lane(group).into_iter().enumerate() {
                *io.get_mut(x, top + dy) = lane;
            }
        }
    }
}

/// Forward 4-point DCT applied component-wise across four lanes.
///
/// `input[n]` holds sample `n` for every component. Coefficient 0 is the mean of the samples.
/// Coefficient `k > 0` is `(√2 / 4) · Σ x[n] cos(π k (2n + 1) / 8)`.
#[inline]
unsafe fn dct4_forward(input: [Lane; 4]) -> [Lane; 4] {
    let [x0, x1, x2, x3] = input;
    let quarter = Lane::splat_f32(0.25);
    let sec0_quarter = Lane::splat_f32(0.5411961 / 4.0);
    let sec1_quarter = Lane::splat_f32(1.306563 / 4.0);

    // Fold the ends and the middle pair.
    let outer_sum = x0.add(x3);
    let inner_sum = x1.add(x2);
    let outer_diff = x0.sub(x3).mul(sec0_quarter);
    let inner_diff = x1.sub(x2).mul(sec1_quarter);
    let odd_sum = outer_diff.add(inner_diff);
    let odd_diff = outer_diff.sub(inner_diff);

    let dc = outer_sum.add(inner_sum).mul(quarter);
    let c1 = odd_sum
        .mul(Lane::splat_f32(std::f32::consts::SQRT_2))
        .add(odd_diff);
    let c2 = outer_sum.sub(inner_sum).mul(quarter);
    [dc, c1, c2, odd_diff]
}

/// Inverse 4-point DCT applied component-wise across four lanes.
///
/// `input[k]` holds coefficient `k` for every component. The result holds the reconstructed
/// samples in order.
#[inline]
pub(crate) unsafe fn dct4_inverse(input: [Lane; 4]) -> [Lane; 4] {
    let [c0, c1, c2, c3] = input;

    let c1_scaled = c1.mul(Lane::splat_f32(std::f32::consts::SQRT_2));
    let c13 = c1.add(c3);
    let odd_a = c1_scaled.add(c13).mul(Lane::splat_f32(0.5411961));
    let odd_b = c1_scaled.sub(c13).mul(Lane::splat_f32(1.306563));

    let even_sum = c0.add(c2);
    let even_diff = c0.sub(c2);

    [
        even_sum.add(odd_a),
        even_diff.add(odd_b),
        even_diff.sub(odd_b),
        even_sum.sub(odd_a),
    ]
}

/// Forward 8-point DCT, in place, down column 0 of `io`, applied component-wise to each lane.
///
/// The even outputs are the 4-point DCT of the folded sums `(x[i] + x[7 - i]) / 2`. The odd
/// outputs come from the 4-point DCT of the folded differences `(x[i] - x[7 - i]) · sec[i] / 2`.
/// In that second transform, the first coefficient is scaled by √2 and adjacent coefficients
/// are summed.
///
/// # Panics
///
/// Panics if `io.height() != 8`.
#[inline]
unsafe fn dct8_forward(io: &mut MutableSubgrid<'_, Lane>) {
    assert!(io.height() == 8);
    let sec = dct_common::sec_half_small(8);
    let half = Lane::splat_f32(0.5);

    // Read all eight samples before any of them is overwritten.
    let x: [Lane; 8] = std::array::from_fn(|y| io.get(0, y));
    let folded_sum: [Lane; 4] = std::array::from_fn(|i| x[i].add(x[7 - i]).mul(half));
    let folded_diff: [Lane; 4] =
        std::array::from_fn(|i| x[i].sub(x[7 - i]).mul(Lane::splat_f32(sec[i] / 2.0)));

    for (k, coeff) in dct4_forward(folded_sum).into_iter().enumerate() {
        *io.get_mut(0, 2 * k) = coeff;
    }

    let mut odd = dct4_forward(folded_diff);
    odd[0] = odd[0].mul(Lane::splat_f32(std::f32::consts::SQRT_2));
    for k in 0..3 {
        *io.get_mut(0, 2 * k + 1) = odd[k].add(odd[k + 1]);
    }
    *io.get_mut(0, 7) = odd[3];
}

/// Inverse 8-point DCT, in place, down column 0 of `io`, applied component-wise to each lane.
///
/// The steps are:
/// 1. Apply a 4-point inverse to the even coefficients.
/// 2. Fold the odd coefficients (`√2·c1`, `c3 + c1`, `c5 + c3`, `c7 + c5`) and apply a 4-point
///    inverse to them.
/// 3. Scale the odd results by `sec_half_small(8)`.
/// 4. Unfold into the front half (even + odd) and the back half in reverse (even - odd).
///
/// # Panics
///
/// Panics if `io.height() != 8`.
#[inline]
unsafe fn dct8_inverse(io: &mut MutableSubgrid<'_, Lane>) {
    assert!(io.height() == 8);
    let sec = dct_common::sec_half_small(8);

    // Read all eight coefficients before any of them is overwritten.
    let x: [Lane; 8] = std::array::from_fn(|y| io.get(0, y));
    let even = dct4_inverse([x[0], x[2], x[4], x[6]]);
    let odd = dct4_inverse([
        x[1].mul(Lane::splat_f32(std::f32::consts::SQRT_2)),
        x[3].add(x[1]),
        x[5].add(x[3]),
        x[7].add(x[5]),
    ]);

    for (k, &s) in sec.iter().enumerate() {
        let scaled = odd[k].mul(Lane::splat_f32(s));
        *io.get_mut(0, k) = even[k].add(scaled);
        *io.get_mut(0, 7 - k) = even[k].sub(scaled);
    }
}

/// In-place 1-D DCT (`DctDirection::Forward`) or inverse DCT over `io`, applied independently
/// to each `f32` component of every lane.
///
/// `scratch` is working storage of the same length as `io`. Its contents on entry are ignored,
/// and it is overwritten. Lengths are handled as follows:
/// - 0 and 1: left untouched.
/// - 2, 4 and 8: dedicated kernels.
/// - Any other power-of-two length: split recursively into two half-length transforms by
///   `dct_split_forward` / `dct_split_inverse`.
///
/// # Panics
///
/// Panics in two cases:
/// - `scratch.len() != io.len()`;
/// - the length is not 0, 1, 2, 4 or 8, and it is not a power of two.
unsafe fn dct(io: &mut [Lane], scratch: &mut [Lane], direction: DctDirection) {
    let n = io.len();
    assert_eq!(scratch.len(), n);

    match n {
        // The 0- and 1-point transforms are the identity.
        0 | 1 => {}
        2 => {
            let sum = io[0].add(io[1]);
            let diff = io[0].sub(io[1]);
            if direction == DctDirection::Forward {
                let half = Lane::splat_f32(0.5);
                io[0] = sum.mul(half);
                io[1] = diff.mul(half);
            } else {
                io[0] = sum;
                io[1] = diff;
            }
        }
        4 => {
            let input = [io[0], io[1], io[2], io[3]];
            let output = if direction == DctDirection::Forward {
                dct4_forward(input)
            } else {
                dct4_inverse(input)
            };
            io.copy_from_slice(&output);
        }
        8 => {
            let mut column = MutableSubgrid::from_buf(io, 1, 8, 1);
            if direction == DctDirection::Forward {
                dct8_forward(&mut column);
            } else {
                dct8_inverse(&mut column);
            }
        }
        _ => {
            assert!(n.is_power_of_two());
            if direction == DctDirection::Forward {
                dct_split_forward(io, scratch);
            } else {
                dct_split_inverse(io, scratch);
            }
        }
    }
}

/// Forward DCT of power-of-two length `n >= 16` by splitting it into two length-`n/2` DCTs.
///
/// Even outputs come from the folded sums. Odd outputs come from the secant-weighted folded
/// differences, where adjacent coefficients are summed and the first one is scaled by √2. The
/// two halves of `io` serve as scratch for the recursive calls.
unsafe fn dct_split_forward(io: &mut [Lane], scratch: &mut [Lane]) {
    let n = io.len();
    let half_len = n / 2;
    let half = Lane::splat_f32(0.5);
    let (even, odd) = scratch.split_at_mut(half_len);

    for (idx, &sec) in dct_common::sec_half(n).iter().enumerate() {
        let head = io[idx];
        let tail = io[n - idx - 1];
        even[idx] = head.add(tail).mul(half);
        odd[idx] = head.sub(tail).mul(Lane::splat_f32(sec / 2.0));
    }

    {
        let (tmp_even, tmp_odd) = io.split_at_mut(half_len);
        dct(even, tmp_even, DctDirection::Forward);
        dct(odd, tmp_odd, DctDirection::Forward);
    }

    // Even coefficients go to io[0], io[2], ...
    for (dst, &v) in io.iter_mut().step_by(2).zip(even.iter()) {
        *dst = v;
    }

    // Odd coefficients go to io[1], io[3], ... as sums of adjacent pairs. The last one stands
    // alone.
    odd[0] = odd[0].mul(Lane::splat_f32(std::f32::consts::SQRT_2));
    for (dst, pair) in io[1..].iter_mut().step_by(2).zip(odd.windows(2)) {
        *dst = pair[0].add(pair[1]);
    }
    io[n - 1] = odd[half_len - 1];
}

/// Inverse DCT of power-of-two length `n >= 16` by splitting it into two length-`n/2` inverses.
///
/// The even coefficients are used directly. The odd coefficients are folded (`√2·c1`, then
/// `c[2i+1] + c[2i-1]`). After the half-length inverses, the odd part is scaled by
/// `sec_half(n)` and unfolded into the front half (even + odd) and the reversed back half
/// (even - odd).
unsafe fn dct_split_inverse(io: &mut [Lane], scratch: &mut [Lane]) {
    let n = io.len();
    let half_len = n / 2;
    let (even, odd) = scratch.split_at_mut(half_len);

    even[0] = io[0];
    odd[0] = io[1].mul(Lane::splat_f32(std::f32::consts::SQRT_2));
    for idx in 1..half_len {
        even[idx] = io[idx * 2];
        odd[idx] = io[idx * 2 + 1].add(io[idx * 2 - 1]);
    }

    let (out_lo, out_hi) = io.split_at_mut(half_len);
    dct(even, out_lo, DctDirection::Inverse);
    dct(odd, out_hi, DctDirection::Inverse);

    for (idx, &sec) in dct_common::sec_half(n).iter().enumerate() {
        let r = odd[idx].mul(Lane::splat_f32(sec));
        out_lo[idx] = even[idx].add(r);
        out_hi[half_len - idx - 1] = even[idx].sub(r);
    }
}