edgefirst-codec 0.23.1

Image codec for decoding JPEG/PNG into pre-allocated EdgeFirst tensors
Documentation
// SPDX-FileCopyrightText: Copyright 2026 Au-Zone Technologies
// SPDX-License-Identifier: Apache-2.0

//! SSE4.1-optimised 8×8 IDCT.
//!
//! Upgrades over SSE2: native `_mm_mullo_epi32` (1 instruction vs 4),
//! `_mm_min_epi32`/`_mm_max_epi32` for branchless clamping (2 vs 5),
//! and `_mm_packus_epi32` for direct u8 packing.

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

const PASS1_BITS: i32 = 2;
const CONST_BITS: i32 = 13;

const FIX_0_298: i32 = 2446;
const FIX_0_390: i32 = 3196;
const FIX_0_541: i32 = 4433;
const FIX_0_765: i32 = 6270;
const FIX_0_899: i32 = 7373;
const FIX_1_175: i32 = 9633;
const FIX_1_501: i32 = 12299;
const FIX_1_847: i32 = 15137;
const FIX_1_961: i32 = 16069;
const FIX_2_053: i32 = 16819;
const FIX_2_562: i32 = 20995;
const FIX_3_072: i32 = 25172;

macro_rules! transpose4x4 {
    ($r0:expr, $r1:expr, $r2:expr, $r3:expr) => {{
        let t0 = _mm_unpacklo_epi32($r0, $r1);
        let t1 = _mm_unpackhi_epi32($r0, $r1);
        let t2 = _mm_unpacklo_epi32($r2, $r3);
        let t3 = _mm_unpackhi_epi32($r2, $r3);
        $r0 = _mm_unpacklo_epi64(t0, t2);
        $r1 = _mm_unpackhi_epi64(t0, t2);
        $r2 = _mm_unpacklo_epi64(t1, t3);
        $r3 = _mm_unpackhi_epi64(t1, t3);
    }};
}

#[inline(always)]
unsafe fn even_part(
    s0: __m128i,
    s2: __m128i,
    s4: __m128i,
    s6: __m128i,
    bias: __m128i,
) -> (__m128i, __m128i, __m128i, __m128i) {
    let tmp0 = _mm_slli_epi32::<CONST_BITS>(s0);
    let tmp2 = _mm_slli_epi32::<CONST_BITS>(s4);

    let tmp10 = _mm_add_epi32(tmp0, tmp2);
    let tmp11 = _mm_sub_epi32(tmp0, tmp2);

    let z1 = _mm_add_epi32(s2, s6);
    let fix_0541 = _mm_set1_epi32(FIX_0_541);
    let fix_0765 = _mm_set1_epi32(FIX_0_765);
    let fix_1847 = _mm_set1_epi32(FIX_1_847);

    let tmp13 = _mm_add_epi32(_mm_mullo_epi32(z1, fix_0541), _mm_mullo_epi32(s2, fix_0765));
    let tmp12 = _mm_sub_epi32(_mm_mullo_epi32(z1, fix_0541), _mm_mullo_epi32(s6, fix_1847));

    let t0 = _mm_add_epi32(_mm_add_epi32(tmp10, tmp13), bias);
    let t3 = _mm_add_epi32(_mm_sub_epi32(tmp10, tmp13), bias);
    let t1 = _mm_add_epi32(_mm_add_epi32(tmp11, tmp12), bias);
    let t2 = _mm_add_epi32(_mm_sub_epi32(tmp11, tmp12), bias);

    (t0, t1, t2, t3)
}

#[inline(always)]
unsafe fn odd_part(
    s1: __m128i,
    s3: __m128i,
    s5: __m128i,
    s7: __m128i,
) -> (__m128i, __m128i, __m128i, __m128i) {
    let z1 = _mm_add_epi32(s7, s1);
    let z2 = _mm_add_epi32(s5, s3);
    let z3 = _mm_add_epi32(s7, s3);
    let z4 = _mm_add_epi32(s5, s1);
    let z5 = _mm_mullo_epi32(_mm_add_epi32(z3, z4), _mm_set1_epi32(FIX_1_175));

    let p7 = _mm_mullo_epi32(s7, _mm_set1_epi32(FIX_0_298));
    let p5 = _mm_mullo_epi32(s5, _mm_set1_epi32(FIX_2_053));
    let p3 = _mm_mullo_epi32(s3, _mm_set1_epi32(FIX_3_072));
    let p1 = _mm_mullo_epi32(s1, _mm_set1_epi32(FIX_1_501));

    let z1 = _mm_mullo_epi32(z1, _mm_set1_epi32(-FIX_0_899));
    let z2 = _mm_mullo_epi32(z2, _mm_set1_epi32(-FIX_2_562));
    let z3 = _mm_add_epi32(_mm_mullo_epi32(z3, _mm_set1_epi32(-FIX_1_961)), z5);
    let z4 = _mm_add_epi32(_mm_mullo_epi32(z4, _mm_set1_epi32(-FIX_0_390)), z5);

    (
        _mm_add_epi32(_mm_add_epi32(p7, z1), z3),
        _mm_add_epi32(_mm_add_epi32(p5, z2), z4),
        _mm_add_epi32(_mm_add_epi32(p3, z2), z3),
        _mm_add_epi32(_mm_add_epi32(p1, z1), z4),
    )
}

/// SSE4.1 8×8 IDCT.
pub fn idct_8x8_sse41(coeffs: &[i32; 64], output: &mut [u8], stride: usize) {
    unsafe { idct_8x8_sse41_inner(coeffs, output, stride) }
}

#[target_feature(enable = "sse4.1")]
unsafe fn idct_8x8_sse41_inner(coeffs: &[i32; 64], output: &mut [u8], stride: usize) {
    let mut workspace = [0i32; 64];

    // Pass 1: columns (two groups of 4)
    let half = _mm_set1_epi32(1 << (CONST_BITS - 1));

    for col_group in 0..2 {
        let base = col_group * 4;

        let s0 = _mm_loadu_si128(coeffs.as_ptr().add(base) as *const __m128i);
        let s1 = _mm_loadu_si128(coeffs.as_ptr().add(8 + base) as *const __m128i);
        let s2 = _mm_loadu_si128(coeffs.as_ptr().add(16 + base) as *const __m128i);
        let s3 = _mm_loadu_si128(coeffs.as_ptr().add(24 + base) as *const __m128i);
        let s4 = _mm_loadu_si128(coeffs.as_ptr().add(32 + base) as *const __m128i);
        let s5 = _mm_loadu_si128(coeffs.as_ptr().add(40 + base) as *const __m128i);
        let s6 = _mm_loadu_si128(coeffs.as_ptr().add(48 + base) as *const __m128i);
        let s7 = _mm_loadu_si128(coeffs.as_ptr().add(56 + base) as *const __m128i);

        let (t0, t1, t2, t3) = even_part(s0, s2, s4, s6, half);
        let (o0, o1, o2, o3) = odd_part(s1, s3, s5, s7);

        let shift = 13 - 2; // CONST_BITS - PASS1_BITS
        let r0 = _mm_srai_epi32::<{ 13 - 2 }>(_mm_add_epi32(t0, o3));
        let r7 = _mm_srai_epi32::<{ 13 - 2 }>(_mm_sub_epi32(t0, o3));
        let r1 = _mm_srai_epi32::<{ 13 - 2 }>(_mm_add_epi32(t1, o2));
        let r6 = _mm_srai_epi32::<{ 13 - 2 }>(_mm_sub_epi32(t1, o2));
        let r2 = _mm_srai_epi32::<{ 13 - 2 }>(_mm_add_epi32(t2, o1));
        let r5 = _mm_srai_epi32::<{ 13 - 2 }>(_mm_sub_epi32(t2, o1));
        let r3 = _mm_srai_epi32::<{ 13 - 2 }>(_mm_add_epi32(t3, o0));
        let r4 = _mm_srai_epi32::<{ 13 - 2 }>(_mm_sub_epi32(t3, o0));

        _mm_storeu_si128(workspace.as_mut_ptr().add(base) as *mut __m128i, r0);
        _mm_storeu_si128(workspace.as_mut_ptr().add(8 + base) as *mut __m128i, r1);
        _mm_storeu_si128(workspace.as_mut_ptr().add(16 + base) as *mut __m128i, r2);
        _mm_storeu_si128(workspace.as_mut_ptr().add(24 + base) as *mut __m128i, r3);
        _mm_storeu_si128(workspace.as_mut_ptr().add(32 + base) as *mut __m128i, r4);
        _mm_storeu_si128(workspace.as_mut_ptr().add(40 + base) as *mut __m128i, r5);
        _mm_storeu_si128(workspace.as_mut_ptr().add(48 + base) as *mut __m128i, r6);
        _mm_storeu_si128(workspace.as_mut_ptr().add(56 + base) as *mut __m128i, r7);

        let _ = shift; // suppress unused warning (shift used in const generic)
    }

    // Pass 2: rows → u8 output using SSE4.1 min/max clamping + pack
    let range_shift = CONST_BITS + PASS1_BITS + 3;
    let bias_val = (1 << (range_shift - 1)) + (128 << range_shift);
    let bias = _mm_set1_epi32(bias_val);
    let zero = _mm_setzero_si128();
    let max255 = _mm_set1_epi32(255);

    for row_group in 0..2 {
        let row_base = row_group * 4 * 8;

        let mut r0l = _mm_loadu_si128(workspace.as_ptr().add(row_base) as *const __m128i);
        let mut r1l = _mm_loadu_si128(workspace.as_ptr().add(row_base + 8) as *const __m128i);
        let mut r2l = _mm_loadu_si128(workspace.as_ptr().add(row_base + 16) as *const __m128i);
        let mut r3l = _mm_loadu_si128(workspace.as_ptr().add(row_base + 24) as *const __m128i);
        transpose4x4!(r0l, r1l, r2l, r3l);

        let mut r0r = _mm_loadu_si128(workspace.as_ptr().add(row_base + 4) as *const __m128i);
        let mut r1r = _mm_loadu_si128(workspace.as_ptr().add(row_base + 8 + 4) as *const __m128i);
        let mut r2r = _mm_loadu_si128(workspace.as_ptr().add(row_base + 16 + 4) as *const __m128i);
        let mut r3r = _mm_loadu_si128(workspace.as_ptr().add(row_base + 24 + 4) as *const __m128i);
        transpose4x4!(r0r, r1r, r2r, r3r);

        let (t0, t1, t2, t3) = even_part(r0l, r2l, r0r, r2r, bias);
        let (o0, o1, o2, o3) = odd_part(r1l, r3l, r1r, r3r);

        // SSE4.1 clamp: max(0, min(255, val >> shift))
        let clamp = |v: __m128i| -> __m128i {
            let shifted = _mm_srai_epi32::<{ 13 + 2 + 3 }>(v);
            _mm_min_epi32(_mm_max_epi32(shifted, zero), max255)
        };

        let out0 = clamp(_mm_add_epi32(t0, o3));
        let out7 = clamp(_mm_sub_epi32(t0, o3));
        let out1 = clamp(_mm_add_epi32(t1, o2));
        let out6 = clamp(_mm_sub_epi32(t1, o2));
        let out2 = clamp(_mm_add_epi32(t2, o1));
        let out5 = clamp(_mm_sub_epi32(t2, o1));
        let out3 = clamp(_mm_add_epi32(t3, o0));
        let out4 = clamp(_mm_sub_epi32(t3, o0));

        // Write output via temp array (same approach as SSE2/NEON)
        let mut tmp = [0i32; 32];
        _mm_storeu_si128(tmp.as_mut_ptr() as *mut __m128i, out0);
        _mm_storeu_si128(tmp.as_mut_ptr().add(4) as *mut __m128i, out1);
        _mm_storeu_si128(tmp.as_mut_ptr().add(8) as *mut __m128i, out2);
        _mm_storeu_si128(tmp.as_mut_ptr().add(12) as *mut __m128i, out3);
        _mm_storeu_si128(tmp.as_mut_ptr().add(16) as *mut __m128i, out4);
        _mm_storeu_si128(tmp.as_mut_ptr().add(20) as *mut __m128i, out5);
        _mm_storeu_si128(tmp.as_mut_ptr().add(24) as *mut __m128i, out6);
        _mm_storeu_si128(tmp.as_mut_ptr().add(28) as *mut __m128i, out7);

        for local_row in 0..4 {
            let row_idx = row_group * 4 + local_row;
            let out_ptr = output.as_mut_ptr().add(row_idx * stride);
            for col in 0..8 {
                *out_ptr.add(col) = tmp[col * 4 + local_row] as u8;
            }
        }
    }
}

/// SSE4.1 DC-only IDCT (same as SSE2 — no benefit from SSE4.1 here).
pub fn idct_dc_only_sse41(dc_value: i32, output: &mut [u8], stride: usize) {
    let range_shift = CONST_BITS + PASS1_BITS + 3;
    let round = 1 << (range_shift - 1);
    let bias = round + (128 << range_shift);
    let scaled = dc_value << (CONST_BITS + PASS1_BITS);
    let val = ((scaled + bias) >> range_shift).clamp(0, 255) as u8;

    unsafe {
        let fill = _mm_set1_epi8(val as i8);
        for row in 0..8 {
            _mm_storel_epi64(output.as_mut_ptr().add(row * stride) as *mut __m128i, fill);
        }
    }
}

#[cfg(test)]
#[cfg(target_arch = "x86_64")]
mod tests {
    use super::super::scalar::{idct_8x8_scalar, idct_dc_only_scalar};
    use super::{idct_8x8_sse41, idct_dc_only_sse41};

    fn make_test_coeffs() -> [i32; 64] {
        let mut c = [0i32; 64];
        c[0] = 256;
        c[1] = 32;
        c[8] = 16;
        c[2] = -24;
        c[16] = 20;
        c[9] = -12;
        c[3] = 8;
        c[24] = -8;
        c
    }

    #[test]
    fn idct_8x8_parity() {
        if !is_x86_feature_detected!("sse4.1") {
            eprintln!("SIMD feature not available, skipping");
            return;
        }

        let coeffs = make_test_coeffs();
        let mut scalar_out = [0u8; 64];
        let mut simd_out = [0u8; 64];

        idct_8x8_scalar(&coeffs, &mut scalar_out, 8);
        idct_8x8_sse41(&coeffs, &mut simd_out, 8);

        for i in 0..64 {
            let diff = (scalar_out[i] as i32 - simd_out[i] as i32).abs();
            assert!(
                diff <= 1,
                "parity mismatch at index {i}: scalar={}, simd={}, diff={}",
                scalar_out[i],
                simd_out[i],
                diff
            );
        }
    }

    #[test]
    fn idct_8x8_parity_zero_block() {
        if !is_x86_feature_detected!("sse4.1") {
            eprintln!("SIMD feature not available, skipping");
            return;
        }

        let coeffs = [0i32; 64];
        let mut scalar_out = [0u8; 64];
        let mut simd_out = [0u8; 64];

        idct_8x8_scalar(&coeffs, &mut scalar_out, 8);
        idct_8x8_sse41(&coeffs, &mut simd_out, 8);

        for i in 0..64 {
            let diff = (scalar_out[i] as i32 - simd_out[i] as i32).abs();
            assert!(
                diff <= 1,
                "zero-block parity mismatch at index {i}: scalar={}, simd={}, diff={}",
                scalar_out[i],
                simd_out[i],
                diff
            );
        }
    }

    #[test]
    fn idct_8x8_parity_strided() {
        if !is_x86_feature_detected!("sse4.1") {
            eprintln!("SIMD feature not available, skipping");
            return;
        }

        let coeffs = make_test_coeffs();
        let stride = 16usize;
        let mut scalar_out = vec![0u8; stride * 8];
        let mut simd_out = vec![0u8; stride * 8];

        idct_8x8_scalar(&coeffs, &mut scalar_out, stride);
        idct_8x8_sse41(&coeffs, &mut simd_out, stride);

        for row in 0..8 {
            for col in 0..8 {
                let i = row * stride + col;
                let diff = (scalar_out[i] as i32 - simd_out[i] as i32).abs();
                assert!(
                    diff <= 1,
                    "strided parity mismatch at row={row} col={col}: scalar={}, simd={}, diff={}",
                    scalar_out[i],
                    simd_out[i],
                    diff
                );
            }
        }
    }

    #[test]
    fn idct_dc_only_parity() {
        if !is_x86_feature_detected!("sse4.1") {
            eprintln!("SIMD feature not available, skipping");
            return;
        }

        for dc in [0i32, 8, 64, 128, -64, 255, -255] {
            let mut scalar_out = [0u8; 64];
            let mut simd_out = [0u8; 64];

            idct_dc_only_scalar(dc, &mut scalar_out, 8);
            idct_dc_only_sse41(dc, &mut simd_out, 8);

            for i in 0..64 {
                let diff = (scalar_out[i] as i32 - simd_out[i] as i32).abs();
                assert!(
                    diff <= 1,
                    "dc_only parity mismatch dc={dc} at index {i}: scalar={}, simd={}, diff={}",
                    scalar_out[i],
                    simd_out[i],
                    diff
                );
            }
        }
    }
}