#![allow(clippy::too_many_arguments)]
#![cfg_attr(not(feature = "unchecked"), forbid(unsafe_code))]
#![cfg_attr(feature = "unchecked", deny(unsafe_code))]
#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;
#[cfg(target_arch = "aarch64")]
use archmage::{Arm64, arcane, rite};
#[cfg(target_arch = "aarch64")]
use safe_unaligned_simd::aarch64 as safe_simd;
use super::itx_arm_neon_8x8::{smull_smlal_q, smull_smlsl_q, sqrshrn_pair, transpose_8x8h};
use super::itx_arm_neon_16x16::idct_16_q;
use super::itx_arm_neon_common::IDCT_COEFFS;
/// Array of sixteen `int16x8_t` vectors (eight `i16` lanes each).
///
/// Used in this file as both the input and the output type of the
/// 16-vector transform stages (`idct_16_q`, `idct32_odd_q`).
#[cfg(target_arch = "aarch64")]
pub(crate) type V16 = [int16x8_t; 16];
/// Odd-half butterfly network of the 32-point inverse DCT, on eight lanes at once.
///
/// `v` holds the sixteen odd-indexed inputs (callers in this file pass input
/// rows/columns 1, 3, 5, ... 31 in that order). The returned array holds the
/// sixteen odd-half terms in the order `t16 .. t31`; callers combine entry
/// `15 - i` with entry `i` of the `idct_16_q` result.
///
/// The structure mirrors `scalar_idct32_odd` step for step, with saturating
/// 16-bit adds/subs in place of plain `i32` arithmetic.
///
/// NOTE(review): `smull_smlsl_q(a, b, c, i, j)` presumably yields the widened
/// `a * c[i] - b * c[j]` and `smull_smlal_q` the widened `a * c[i] + b * c[j]`,
/// and `sqrshrn_pair` presumably performs the rounding narrow back to 16 bits
/// (the `+ 2048 >> 12` of the scalar mirror) — confirm against
/// `itx_arm_neon_8x8`.
#[cfg(target_arch = "aarch64")]
#[rite(neon)]
pub(crate) fn idct32_odd_q(v: V16) -> V16 {
// Two coefficient vectors for the first rotation stage, taken from
// IDCT_COEFFS[16..24] and IDCT_COEFFS[24..32].
let c0 = safe_simd::vld1q_s16(<&[i16; 8]>::try_from(&IDCT_COEFFS[16..24]).unwrap());
let c1 = safe_simd::vld1q_s16(<&[i16; 8]>::try_from(&IDCT_COEFFS[24..32]).unwrap());
// Stage 1: eight rotations, each pairing two inputs with one coefficient
// pair (lanes 2k and 2k+1 of c0 / c1) to produce t(16+k)a and its mirror.
let (lo, hi) = smull_smlsl_q(v[0], v[15], c0, 0, 1);
let t16a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(v[0], v[15], c0, 1, 0);
let t31a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(v[8], v[7], c0, 2, 3);
let t17a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(v[8], v[7], c0, 3, 2);
let t30a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(v[4], v[11], c0, 4, 5);
let t18a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(v[4], v[11], c0, 5, 4);
let t29a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(v[12], v[3], c0, 6, 7);
let t19a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(v[12], v[3], c0, 7, 6);
let t28a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(v[2], v[13], c1, 0, 1);
let t20a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(v[2], v[13], c1, 1, 0);
let t27a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(v[10], v[5], c1, 2, 3);
let t21a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(v[10], v[5], c1, 3, 2);
let t26a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(v[6], v[9], c1, 4, 5);
let t22a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(v[6], v[9], c1, 5, 4);
let t25a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(v[14], v[1], c1, 6, 7);
let t23a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(v[14], v[1], c1, 7, 6);
let t24a = sqrshrn_pair(lo, hi);
// Coefficients shared by all later stages, taken from IDCT_COEFFS[0..8].
let c_main = safe_simd::vld1q_s16(<&[i16; 8]>::try_from(&IDCT_COEFFS[0..8]).unwrap());
// Stage 2: saturating add/sub butterflies on neighbouring stage-1 pairs.
let s17 = vqsubq_s16(t16a, t17a); let s16 = vqaddq_s16(t16a, t17a); let s30 = vqsubq_s16(t31a, t30a); let s31 = vqaddq_s16(t31a, t30a); let s18 = vqsubq_s16(t19a, t18a); let s19 = vqaddq_s16(t19a, t18a); let s20 = vqaddq_s16(t20a, t21a); let s21 = vqsubq_s16(t20a, t21a); let s22 = vqsubq_s16(t23a, t22a); let s23 = vqaddq_s16(t23a, t22a); let s24 = vqaddq_s16(t24a, t25a); let s25 = vqsubq_s16(t24a, t25a); let s26 = vqsubq_s16(t27a, t26a); let s27 = vqaddq_s16(t27a, t26a); let s28 = vqaddq_s16(t28a, t29a); let s29 = vqsubq_s16(t28a, t29a);
// Stage 3: rotations using c_main lanes 4/5 and 6/7. u18a and u22a are the
// negated sum form, hence the explicit vnegq_s32 before narrowing.
let (lo, hi) = smull_smlsl_q(s30, s17, c_main, 4, 5);
let u17a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(s30, s17, c_main, 5, 4);
let u30a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(s29, s18, c_main, 5, 4);
let u18a = sqrshrn_pair(vnegq_s32(lo), vnegq_s32(hi));
let (lo, hi) = smull_smlsl_q(s29, s18, c_main, 4, 5);
let u29a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(s26, s21, c_main, 6, 7);
let u21a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(s26, s21, c_main, 7, 6);
let u26a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(s25, s22, c_main, 7, 6);
let u22a = sqrshrn_pair(vnegq_s32(lo), vnegq_s32(hi));
let (lo, hi) = smull_smlsl_q(s25, s22, c_main, 6, 7);
let u25a = sqrshrn_pair(lo, hi);
// Stage 4: saturating add/sub butterflies across groups of four.
let w30 = vqaddq_s16(u30a, u29a); let w29 = vqsubq_s16(u30a, u29a); let w18 = vqsubq_s16(u17a, u18a); let w17 = vqaddq_s16(u17a, u18a); let w19a = vqsubq_s16(s16, s19); let w16a = vqaddq_s16(s16, s19); let w20a = vqsubq_s16(s23, s20); let w23a = vqaddq_s16(s23, s20); let w21 = vqsubq_s16(u22a, u21a); let w22 = vqaddq_s16(u22a, u21a); let w24a = vqaddq_s16(s24, s27); let w27a = vqsubq_s16(s24, s27); let w25 = vqaddq_s16(u25a, u26a); let w26 = vqsubq_s16(u25a, u26a); let w28a = vqsubq_s16(s31, s28); let w31a = vqaddq_s16(s31, s28);
// Stage 5: rotations using c_main lanes 2/3; x20 and x21a are negated sums.
let (lo, hi) = smull_smlsl_q(w29, w18, c_main, 2, 3);
let x18a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(w29, w18, c_main, 3, 2);
let x29a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(w28a, w19a, c_main, 2, 3);
let x19 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(w28a, w19a, c_main, 3, 2);
let x28 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(w27a, w20a, c_main, 3, 2);
let x20 = sqrshrn_pair(vnegq_s32(lo), vnegq_s32(hi));
let (lo, hi) = smull_smlsl_q(w27a, w20a, c_main, 2, 3);
let x27 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(w26, w21, c_main, 3, 2);
let x21a = sqrshrn_pair(vnegq_s32(lo), vnegq_s32(hi));
let (lo, hi) = smull_smlsl_q(w26, w21, c_main, 2, 3);
let x26a = sqrshrn_pair(lo, hi);
// Stage 6: saturating add/sub butterflies across groups of eight.
let y23 = vqsubq_s16(w16a, w23a); let y16 = vqaddq_s16(w16a, w23a); let y24 = vqsubq_s16(w31a, w24a); let y31 = vqaddq_s16(w31a, w24a); let y22a = vqsubq_s16(w17, w22); let y17a = vqaddq_s16(w17, w22); let y30a = vqaddq_s16(w30, w25); let y25a = vqsubq_s16(w30, w25); let y21 = vqsubq_s16(x18a, x21a); let y18 = vqaddq_s16(x18a, x21a); let y19a = vqaddq_s16(x19, x20); let y20a = vqsubq_s16(x19, x20); let y29 = vqaddq_s16(x29a, x26a); let y26 = vqsubq_s16(x29a, x26a); let y28a = vqaddq_s16(x28, x27); let y27a = vqsubq_s16(x28, x27);
// Stage 7: the middle eight terms are rotated with c_main lane 0 used for
// both multipliers (sum and difference scaled by the same coefficient).
let (lo, hi) = smull_smlsl_q(y27a, y20a, c_main, 0, 0);
let z20 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(y27a, y20a, c_main, 0, 0);
let z27 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(y26, y21, c_main, 0, 0);
let z26a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(y26, y21, c_main, 0, 0);
let z21a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(y25a, y22a, c_main, 0, 0);
let z22 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(y25a, y22a, c_main, 0, 0);
let z25 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(y24, y23, c_main, 0, 0);
let z23a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(y24, y23, c_main, 0, 0);
let z24a = sqrshrn_pair(lo, hi);
// Results in ascending term order t16 .. t31.
[
y16, y17a, y18, y19a, z20, z21a, z22, z23a, z24a, z25, z26a, z27, y28a, y29, y30a, y31, ]
}
/// Merges the even-half and odd-half results into the 32 transform outputs.
///
/// For `i` in `0..16`, output `i` is the saturating sum `even[i] + odd[15 - i]`
/// and output `31 - i` is the saturating difference of the same pair.
#[cfg(target_arch = "aarch64")]
#[rite(neon)]
pub(crate) fn idct_32_full(even: &V16, odd: &V16) -> [int16x8_t; 32] {
    let mut merged = [vdupq_n_s16(0); 32];
    // Walk `even` forwards and `odd` backwards so each step sees the pair
    // (even[i], odd[15 - i]).
    for (idx, (&e, &o)) in even.iter().zip(odd.iter().rev()).enumerate() {
        merged[idx] = vqaddq_s16(e, o);
        merged[31 - idx] = vqsubq_s16(e, o);
    }
    merged
}
/// First (horizontal) pass of the 32-point inverse DCT for a strip of 8 rows.
///
/// Reads 32 groups of eight coefficients, group `k` starting at
/// `coeff[coeff_base + k * coeff_stride]`, and zeroes each group after loading
/// it. Even-numbered groups go through `idct_16_q`, odd-numbered groups through
/// `idct32_odd_q`; the two halves are transposed, combined, rounded with a
/// right shift of 2 and written as eight rows of 32 values starting at
/// `scratch[scratch_base]` (row pitch 32).
///
/// `shift` is accepted but not consulted: the rounding shift is fixed at 2.
///
/// # Panics
/// Panics if any accessed range of `coeff` or `scratch` is out of bounds.
#[cfg(target_arch = "aarch64")]
#[rite(neon)]
pub(crate) fn horz_dct_32x8(
    coeff: &mut [i16],
    coeff_base: usize,
    coeff_stride: usize,
    scratch: &mut [i16],
    scratch_base: usize,
    shift: i16,
) {
    // Kept for signature compatibility; see the doc comment above.
    let _ = shift;
    let zeroed = vdupq_n_s16(0);

    // Even-numbered groups: load, clear in place, run the 16-point stage.
    let mut even_cols: V16 = [zeroed; 16];
    for (idx, slot) in even_cols.iter_mut().enumerate() {
        let start = coeff_base + (idx * 2) * coeff_stride;
        *slot = safe_simd::vld1q_s16(<&[i16; 8]>::try_from(&coeff[start..start + 8]).unwrap());
        coeff[start..start + 8].fill(0);
    }
    let even_res = idct_16_q(even_cols);

    // After transposition each vector holds eight consecutive outputs of one row.
    let (e0, e1, e2, e3, e4, e5, e6, e7) = transpose_8x8h(
        even_res[0],
        even_res[1],
        even_res[2],
        even_res[3],
        even_res[4],
        even_res[5],
        even_res[6],
        even_res[7],
    );
    let (e8, e9, e10, e11, e12, e13, e14, e15) = transpose_8x8h(
        even_res[8],
        even_res[9],
        even_res[10],
        even_res[11],
        even_res[12],
        even_res[13],
        even_res[14],
        even_res[15],
    );
    let even_first = [e0, e1, e2, e3, e4, e5, e6, e7];
    let even_second = [e8, e9, e10, e11, e12, e13, e14, e15];

    // Odd-numbered groups: same load-and-clear, then the odd-half stage.
    let mut odd_cols: V16 = [zeroed; 16];
    for (idx, slot) in odd_cols.iter_mut().enumerate() {
        let start = coeff_base + (idx * 2 + 1) * coeff_stride;
        *slot = safe_simd::vld1q_s16(<&[i16; 8]>::try_from(&coeff[start..start + 8]).unwrap());
        coeff[start..start + 8].fill(0);
    }
    let odd_res = idct32_odd_q(odd_cols);

    // The odd terms are fed to the transpose in descending order so that lane
    // `j` of a transposed row already holds term `15 - j` (resp. `7 - j`),
    // which is the partner of even output `j` (resp. `8 + j`).
    let (h0, h1, h2, h3, h4, h5, h6, h7) = transpose_8x8h(
        odd_res[15],
        odd_res[14],
        odd_res[13],
        odd_res[12],
        odd_res[11],
        odd_res[10],
        odd_res[9],
        odd_res[8],
    );
    let (l0, l1, l2, l3, l4, l5, l6, l7) = transpose_8x8h(
        odd_res[7],
        odd_res[6],
        odd_res[5],
        odd_res[4],
        odd_res[3],
        odd_res[2],
        odd_res[1],
        odd_res[0],
    );
    let odd_upper = [h0, h1, h2, h3, h4, h5, h6, h7];
    let odd_lower = [l0, l1, l2, l3, l4, l5, l6, l7];

    for row in 0..8 {
        let ev_a = even_first[row];
        let ev_b = even_second[row];
        let od_a = odd_upper[row];
        let od_b = odd_lower[row];

        let sum_a = vqaddq_s16(ev_a, od_a);
        let sum_b = vqaddq_s16(ev_b, od_b);
        let diff_a = vqsubq_s16(ev_a, od_a);
        let diff_b = vqsubq_s16(ev_b, od_b);

        // The differences are the mirrored outputs, so they are lane-reversed
        // and emitted in swapped order to land at positions 16..32.
        let quarters = [sum_a, sum_b, rev128_s16(diff_b), rev128_s16(diff_a)];

        let row_base = scratch_base + row * 32;
        for (q, &lanes) in quarters.iter().enumerate() {
            store_v16(scratch, row_base + q * 8, vrshrq_n_s16::<2>(lanes));
        }
    }
}
/// Reverses the order of all eight `i16` lanes of `v`.
#[cfg(target_arch = "aarch64")]
#[rite(neon)]
#[inline(always)]
pub(crate) fn rev128_s16(v: int16x8_t) -> int16x8_t {
    // Reverse each 4-lane half on its own, then swap the halves on recombine.
    let upper_reversed = vrev64_s16(vget_high_s16(v));
    let lower_reversed = vrev64_s16(vget_low_s16(v));
    vcombine_s16(upper_reversed, lower_reversed)
}
/// Writes the eight lanes of `v` to `buf[off..off + 8]`.
///
/// # Panics
/// Panics if `off + 8` exceeds `buf.len()`.
#[cfg(target_arch = "aarch64")]
#[rite(neon)]
#[inline(always)]
pub(crate) fn store_v16(buf: &mut [i16], off: usize, v: int16x8_t) {
    let window: &mut [i16; 8] = (&mut buf[off..off + 8]).try_into().unwrap();
    safe_simd::vst1q_s16(window, v);
}
/// Loads `buf[off..off + 8]` into a vector of eight `i16` lanes.
///
/// # Panics
/// Panics if `off + 8` exceeds `buf.len()`.
#[cfg(target_arch = "aarch64")]
#[rite(neon)]
#[inline(always)]
pub(crate) fn load_v16(buf: &[i16], off: usize) -> int16x8_t {
    let window = <&[i16; 8]>::try_from(&buf[off..off + 8]).unwrap();
    safe_simd::vld1q_s16(window)
}
/// Second (vertical) pass for an 8-pixel-wide, 32-row strip of 8-bit pixels.
///
/// Loads 32 rows of eight values from `scratch` (row `r` at
/// `scratch_base + r * scratch_stride`), runs the even rows through
/// `idct_16_q` and the odd rows through `idct32_odd_q`, combines them, applies
/// a rounding right shift of 4 and adds the result to the eight pixels of each
/// destination row with saturation to the `u8` range.
///
/// Destination row `y` starts at `dst_base + y * dst_stride` (computed with
/// wrapping signed addition).
///
/// # Panics
/// Panics if any accessed range of `scratch` or `dst` is out of bounds.
#[cfg(target_arch = "aarch64")]
#[rite(neon)]
pub(crate) fn vert_dct_add_8x32_8bpc(
    dst: &mut [u8],
    dst_base: usize,
    dst_stride: isize,
    scratch: &[i16],
    scratch_base: usize,
    scratch_stride: usize,
) {
    let mut even_rows: V16 = [vdupq_n_s16(0); 16];
    for (k, slot) in even_rows.iter_mut().enumerate() {
        *slot = load_v16(scratch, scratch_base + (k * 2) * scratch_stride);
    }
    let even_out = idct_16_q(even_rows);

    let mut odd_rows: V16 = [vdupq_n_s16(0); 16];
    for (k, slot) in odd_rows.iter_mut().enumerate() {
        *slot = load_v16(scratch, scratch_base + (k * 2 + 1) * scratch_stride);
    }
    let odd_out = idct32_odd_q(odd_rows);

    for y in 0..32usize {
        // Rows 0..16 are even + odd; rows 16..32 are the mirrored differences.
        let merged = if y < 16 {
            vqaddq_s16(even_out[y], odd_out[15 - y])
        } else {
            vqsubq_s16(even_out[31 - y], odd_out[y - 16])
        };
        let residual = vrshrq_n_s16::<4>(merged);

        let at = dst_base.wrapping_add_signed(y as isize * dst_stride);
        let pixels: &mut [u8; 8] = (&mut dst[at..at + 8]).try_into().unwrap();
        // Widen the pixels onto the residual, then narrow with unsigned saturation.
        let widened = vaddw_u8(vreinterpretq_u16_s16(residual), safe_simd::vld1_u8(&*pixels));
        safe_simd::vst1_u8(pixels, vqmovun_s16(vreinterpretq_s16_u16(widened)));
    }
}
/// Second (vertical) pass for an 8-pixel-wide, 32-row strip of 16-bit pixels.
///
/// Same data flow as the 8-bit variant: 32 rows of eight values are read from
/// `scratch`, split into even/odd rows for `idct_16_q` / `idct32_odd_q`,
/// combined, rounded with a right shift of 4 and added to `dst`. The sum is
/// clamped to `0..=bitdepth_max` (after `bitdepth_max` is truncated to `i16`).
///
/// Pixels are reinterpreted as `i16` for the saturating add.
/// NOTE(review): this assumes pixel values fit in 15 bits — confirm with callers.
///
/// # Panics
/// Panics if any accessed range of `scratch` or `dst` is out of bounds.
#[cfg(target_arch = "aarch64")]
#[rite(neon)]
pub(crate) fn vert_dct_add_8x32_16bpc(
    dst: &mut [u16],
    dst_base: usize,
    dst_stride: isize,
    scratch: &[i16],
    scratch_base: usize,
    scratch_stride: usize,
    bitdepth_max: i32,
) {
    let ceiling = vdupq_n_s16(bitdepth_max as i16);
    let floor = vdupq_n_s16(0);

    let mut even_rows: V16 = [vdupq_n_s16(0); 16];
    for (k, slot) in even_rows.iter_mut().enumerate() {
        *slot = load_v16(scratch, scratch_base + (k * 2) * scratch_stride);
    }
    let even_out = idct_16_q(even_rows);

    let mut odd_rows: V16 = [vdupq_n_s16(0); 16];
    for (k, slot) in odd_rows.iter_mut().enumerate() {
        *slot = load_v16(scratch, scratch_base + (k * 2 + 1) * scratch_stride);
    }
    let odd_out = idct32_odd_q(odd_rows);

    for y in 0..32usize {
        // Rows 0..16 are even + odd; rows 16..32 are the mirrored differences.
        let merged = if y < 16 {
            vqaddq_s16(even_out[y], odd_out[15 - y])
        } else {
            vqsubq_s16(even_out[31 - y], odd_out[y - 16])
        };
        let residual = vrshrq_n_s16::<4>(merged);

        let at = dst_base.wrapping_add_signed(y as isize * dst_stride);
        let row = &mut dst[at..at + 8];

        let mut lanes = [0i16; 8];
        for (lane, &px) in lanes.iter_mut().zip(row.iter()) {
            *lane = px as i16;
        }
        let summed = vqaddq_s16(safe_simd::vld1q_s16(&lanes), residual);
        let bounded = vminq_s16(vmaxq_s16(summed, floor), ceiling);
        safe_simd::vst1q_s16(&mut lanes, bounded);
        for (px, &lane) in row.iter_mut().zip(lanes.iter()) {
            *px = lane as u16;
        }
    }
}
/// DC-only shortcut for the 32x32 DCT on 8-bit pixels.
///
/// Takes `coeff[0]` (and resets it to zero), scales it twice by `2896 * 8`
/// via `vqrdmulhq_s16` with rounding right shifts of 2 and then 4, and adds
/// the resulting constant to every pixel of the 32x32 block with saturation
/// to the `u8` range.
///
/// # Panics
/// Panics if `coeff` is empty or any addressed range of `dst` is out of bounds.
#[cfg(target_arch = "aarch64")]
#[rite(neon)]
fn dc_only_32x32_8bpc(dst: &mut [u8], dst_base: usize, dst_stride: isize, coeff: &mut [i16]) {
    let dc = core::mem::replace(&mut coeff[0], 0);

    let factor = vdupq_n_s16((2896 * 8) as i16);
    let after_first = vrshrq_n_s16::<2>(vqrdmulhq_s16(vdupq_n_s16(dc), factor));
    let after_second = vrshrq_n_s16::<4>(vqrdmulhq_s16(after_first, factor));
    let offset = vreinterpretq_u16_s16(after_second);

    for y in 0..32usize {
        let row_at = dst_base.wrapping_add_signed(y as isize * dst_stride);
        for x in (0..32usize).step_by(8) {
            let at = row_at + x;
            let pixels: &mut [u8; 8] = (&mut dst[at..at + 8]).try_into().unwrap();
            let widened = vaddw_u8(offset, safe_simd::vld1_u8(&*pixels));
            safe_simd::vst1_u8(pixels, vqmovun_s16(vreinterpretq_s16_u16(widened)));
        }
    }
}
/// DC-only shortcut for the 32x32 DCT on 16-bit pixels.
///
/// Takes `coeff[0]` (and resets it to zero), applies the two-pass scaling in
/// scalar arithmetic — multiply by `2896 * 8` with a rounded `>> 15`, rounded
/// `>> 2`, multiply again, rounded `>> 4` — clamps the result to the `i16`
/// range and adds it to every pixel of the 32x32 block, clamping each sum to
/// `0..=bitdepth_max` (after `bitdepth_max` is truncated to `i16`).
///
/// # Panics
/// Panics if `coeff` is empty or any addressed element of `dst` is out of bounds.
#[cfg(target_arch = "aarch64")]
#[rite(neon)]
fn dc_only_32x32_16bpc(
    dst: &mut [u16],
    dst_base: usize,
    dst_stride: isize,
    coeff: &mut [i32],
    bitdepth_max: i32,
) {
    let dc_val = core::mem::replace(&mut coeff[0], 0);

    // Products are formed in i64 so the multiply cannot overflow.
    let factor = i64::from(2896i32 * 8);
    let first = ((i64::from(dc_val) * factor + 16384) >> 15) as i32;
    let first = (first + 2) >> 2;
    let second = ((i64::from(first) * factor + 16384) >> 15) as i32;
    let second = (second + 8) >> 4;
    let offset = vdupq_n_s16(second.clamp(i16::MIN as i32, i16::MAX as i32) as i16);

    let ceiling = vdupq_n_s16(bitdepth_max as i16);
    let floor = vdupq_n_s16(0);

    for y in 0..32usize {
        let row_at = dst_base.wrapping_add_signed(y as isize * dst_stride);
        for x in (0..32usize).step_by(8) {
            let at = row_at + x;
            let row = &mut dst[at..at + 8];

            let mut lanes = [0i16; 8];
            for (lane, &px) in lanes.iter_mut().zip(row.iter()) {
                *lane = px as i16;
            }
            let summed = vqaddq_s16(safe_simd::vld1q_s16(&lanes), offset);
            let bounded = vminq_s16(vmaxq_s16(summed, floor), ceiling);
            safe_simd::vst1q_s16(&mut lanes, bounded);
            for (px, &lane) in row.iter_mut().zip(lanes.iter()) {
                *px = lane as u16;
            }
        }
    }
}
/// Token-gated entry point for the 32x32 identity/identity inverse transform
/// on 8-bit pixels.
///
/// Forwards `dst`, `dst_base`, `dst_stride`, `coeff` and `eob` unchanged to
/// `identity_32x32_8bpc_impl`. `_token` and `_bitdepth_max` are not read by
/// the body.
#[cfg(target_arch = "aarch64")]
#[arcane]
pub(crate) fn inv_txfm_add_identity_identity_32x32_8bpc_neon_inner(
_token: Arm64,
dst: &mut [u8],
dst_base: usize,
dst_stride: isize,
coeff: &mut [i16],
eob: i32,
_bitdepth_max: i32,
) {
identity_32x32_8bpc_impl(dst, dst_base, dst_stride, coeff, eob);
}
/// 32x32 identity/identity inverse transform added onto 8-bit pixels.
///
/// The block is processed as a 4x4 grid of 8x8 tiles. For each tile, eight
/// groups of eight coefficients are loaded from `coeff` (group for column `c`,
/// rows `r0..r0 + 8` at `coeff[c * 32 + r0]`) and zeroed, transposed, rounded
/// with a right shift of 2 and added to the matching 8x8 area of `dst` with
/// saturation to the `u8` range.
///
/// Tile rows and tile columns past the first are skipped once `eob` falls
/// below the threshold of the preceding group.
///
/// # Panics
/// Panics if any accessed range of `coeff` or `dst` is out of bounds.
#[cfg(target_arch = "aarch64")]
#[rite(neon)]
fn identity_32x32_8bpc_impl(
    dst: &mut [u8],
    dst_base: usize,
    dst_stride: isize,
    coeff: &mut [i16],
    eob: i32,
) {
    // The same cut-offs are applied to tile rows and tile columns.
    const EOB_LIMITS: [i32; 4] = [36, 136, 300, 1024];

    for tile_row in 0..4usize {
        if tile_row != 0 && eob < EOB_LIMITS[tile_row - 1] {
            break;
        }
        let row0 = tile_row * 8;

        for tile_col in 0..4usize {
            if tile_col != 0 && eob < EOB_LIMITS[tile_col - 1] {
                break;
            }
            let col0 = tile_col * 8;

            let mut cols = [vdupq_n_s16(0); 8];
            for (c, slot) in cols.iter_mut().enumerate() {
                let start = (col0 + c) * 32 + row0;
                *slot =
                    safe_simd::vld1q_s16(<&[i16; 8]>::try_from(&coeff[start..start + 8]).unwrap());
                coeff[start..start + 8].fill(0);
            }

            let [c0, c1, c2, c3, c4, c5, c6, c7] = cols;
            let (r0, r1, r2, r3, r4, r5, r6, r7) =
                transpose_8x8h(c0, c1, c2, c3, c4, c5, c6, c7);

            for (r, &row_vec) in [r0, r1, r2, r3, r4, r5, r6, r7].iter().enumerate() {
                let residual = vrshrq_n_s16::<2>(row_vec);
                let at = dst_base.wrapping_add_signed((row0 + r) as isize * dst_stride) + col0;
                let pixels: &mut [u8; 8] = (&mut dst[at..at + 8]).try_into().unwrap();
                let widened =
                    vaddw_u8(vreinterpretq_u16_s16(residual), safe_simd::vld1_u8(&*pixels));
                safe_simd::vst1_u8(pixels, vqmovun_s16(vreinterpretq_s16_u16(widened)));
            }
        }
    }
}
/// 32x32 identity/identity inverse transform added onto 16-bit pixels.
///
/// Scalar implementation working on a 4x4 grid of 8x8 tiles. Each tile's
/// coefficients (`coeff[col * 32 + row]`) are first copied out and zeroed,
/// then each value is rounded with `(v + 2) >> 2`, added to the destination
/// pixel and clamped to `0..=bitdepth_max`.
///
/// Tile rows and tile columns past the first are skipped once `eob` falls
/// below the threshold of the preceding group. `_token` is not read.
///
/// # Panics
/// Panics if any accessed element of `coeff` or `dst` is out of bounds.
#[cfg(target_arch = "aarch64")]
#[arcane]
pub(crate) fn inv_txfm_add_identity_identity_32x32_16bpc_neon_inner(
    _token: Arm64,
    dst: &mut [u16],
    dst_base: usize,
    dst_stride: isize,
    coeff: &mut [i32],
    eob: i32,
    bitdepth_max: i32,
) {
    const EOB_LIMITS: [i32; 4] = [36, 136, 300, 1024];

    for tile_row in 0..4usize {
        if tile_row != 0 && eob < EOB_LIMITS[tile_row - 1] {
            break;
        }
        let row0 = tile_row * 8;

        for tile_col in 0..4usize {
            if tile_col != 0 && eob < EOB_LIMITS[tile_col - 1] {
                break;
            }
            let col0 = tile_col * 8;

            // Phase 1: pull the tile out of `coeff`, clearing as we go.
            // `tile[c][r]` is the coefficient at column `col0 + c`, row `row0 + r`.
            let mut tile = [[0i32; 8]; 8];
            for (c, column) in tile.iter_mut().enumerate() {
                let start = (col0 + c) * 32 + row0;
                for (r, cell) in column.iter_mut().enumerate() {
                    *cell = core::mem::replace(&mut coeff[start + r], 0);
                }
            }

            // Phase 2: add the rounded values onto the destination pixels.
            for r in 0..8 {
                let at = dst_base.wrapping_add_signed((row0 + r) as isize * dst_stride) + col0;
                for (c, column) in tile.iter().enumerate() {
                    let delta = (column[r] + 2) >> 2;
                    let current = dst[at + c] as i32;
                    dst[at + c] = (current + delta).clamp(0, bitdepth_max) as u16;
                }
            }
        }
    }
}
/// 32x32 DCT/DCT inverse transform added onto 8-bit pixels.
///
/// With `eob == 0` only the DC shortcut runs. Otherwise the first pass is
/// applied to up to four 8-row strips of `coeff` (later strips are skipped
/// once `eob` falls below the threshold of the preceding strip) into a zeroed
/// 32x32 scratch buffer, and the second pass then runs over all four
/// 8-column strips of that buffer, adding into `dst`.
///
/// `_token` and `_bitdepth_max` are not read by the body.
#[cfg(target_arch = "aarch64")]
#[arcane]
pub(crate) fn inv_txfm_add_dct_dct_32x32_8bpc_neon_inner(
    _token: Arm64,
    dst: &mut [u8],
    dst_base: usize,
    dst_stride: isize,
    coeff: &mut [i16],
    eob: i32,
    _bitdepth_max: i32,
) {
    if eob == 0 {
        dc_only_32x32_8bpc(dst, dst_base, dst_stride, coeff);
        return;
    }

    const EOB_LIMITS: [i32; 4] = [36, 136, 300, 1024];
    let mut scratch = [0i16; 1024];

    // First pass: strips that are skipped leave their scratch rows at zero.
    for (strip, row_start) in (0..32usize).step_by(8).enumerate() {
        if strip > 0 && eob < EOB_LIMITS[strip - 1] {
            break;
        }
        horz_dct_32x8(coeff, row_start, 32, &mut scratch, row_start * 32, 2);
    }

    // Second pass: always covers the full width.
    for col_start in (0..32usize).step_by(8) {
        vert_dct_add_8x32_8bpc(dst, dst_base + col_start, dst_stride, &scratch, col_start, 32);
    }
}
/// 32x32 DCT/DCT inverse transform added onto 16-bit pixels.
///
/// Dispatches on `eob`: zero selects the DC-only shortcut, anything else the
/// scalar two-pass implementation. `_token` is not read by the body.
#[cfg(target_arch = "aarch64")]
#[arcane]
pub(crate) fn inv_txfm_add_dct_dct_32x32_16bpc_neon_inner(
    _token: Arm64,
    dst: &mut [u16],
    dst_base: usize,
    dst_stride: isize,
    coeff: &mut [i32],
    eob: i32,
    bitdepth_max: i32,
) {
    if eob != 0 {
        scalar_dct_dct_32x32_16bpc(dst, dst_base, dst_stride, coeff, eob, bitdepth_max);
    } else {
        dc_only_32x32_16bpc(dst, dst_base, dst_stride, coeff, bitdepth_max);
    }
}
/// Scalar two-pass 32x32 DCT/DCT inverse transform added onto 16-bit pixels.
///
/// Pass one transforms each of the 32 rows of `coeff` (element for row `y`,
/// column `x` at `coeff[y + x * 32]`) and stores the result with a rounding
/// right shift of 2. Pass two transforms each column of that intermediate,
/// applies a rounding right shift of 4, adds the value to the destination
/// pixel and clamps the sum to `0..=bitdepth_max`. Finally the first 1024
/// entries of `coeff` are zeroed.
///
/// The shift pair (2, then 4) is the one used by every other 32x32 path in
/// this file (`horz_dct_32x8` + `vert_dct_add_8x32_*`, and both `dc_only_*`
/// functions). The previous version skipped the intermediate shift and
/// applied a single `>> 10` at the end, so a block with `eob != 0` was scaled
/// 16 times smaller than the same DC value on the `eob == 0` path.
///
/// NOTE(review): the reference decoder also clamps the intermediate values
/// between the passes; that clamp is not reproduced here — confirm whether it
/// is needed for bit-exact output on extreme inputs.
///
/// `_eob` is not used: all 1024 coefficients are always processed.
///
/// # Panics
/// Panics if `coeff` has fewer than 1024 elements or any addressed element of
/// `dst` is out of bounds.
#[allow(dead_code)]
fn scalar_dct_dct_32x32_16bpc(
    dst: &mut [u16],
    dst_base: usize,
    dst_stride: isize,
    coeff: &mut [i32],
    _eob: i32,
    bitdepth_max: i32,
) {
    // Row-major intermediate: tmp[y * 32 + x].
    let mut tmp = [0i32; 1024];

    // Pass one: one 32-point transform per row, rounded down by 2 bits.
    for y in 0..32 {
        let mut input = [0i32; 32];
        for (x, slot) in input.iter_mut().enumerate() {
            *slot = coeff[y + x * 32];
        }
        let out = scalar_dct32_1d(&input);
        for (x, &value) in out.iter().enumerate() {
            tmp[y * 32 + x] = (value + 2) >> 2;
        }
    }

    // Pass two: one 32-point transform per column, rounded down by 4 bits and
    // accumulated into the destination.
    for x in 0..32 {
        let mut input = [0i32; 32];
        for (y, slot) in input.iter_mut().enumerate() {
            *slot = tmp[y * 32 + x];
        }
        let out = scalar_dct32_1d(&input);
        for (y, &value) in out.iter().enumerate() {
            let row_off = dst_base.wrapping_add_signed(y as isize * dst_stride);
            let current = dst[row_off + x] as i32;
            let residual = (value + 8) >> 4;
            dst[row_off + x] = (current + residual).clamp(0, bitdepth_max) as u16;
        }
    }

    coeff[..1024].fill(0);
}
/// Scalar 32-point inverse DCT of one row or column.
///
/// Even-indexed inputs go through `scalar_dct16_1d`, odd-indexed inputs
/// through `scalar_idct32_odd`. Output `i` (for `i < 16`) is
/// `even[i] + odd[15 - i]`, and output `31 - i` is `even[i] - odd[15 - i]`.
#[allow(dead_code)]
fn scalar_dct32_1d(input: &[i32; 32]) -> [i32; 32] {
    let even_in: [i32; 16] = core::array::from_fn(|k| input[2 * k]);
    let even_out = scalar_dct16_1d(&even_in);

    let odd_in: [i32; 16] = core::array::from_fn(|k| input[2 * k + 1]);
    let odd_out = scalar_idct32_odd(&odd_in);

    let mut result = [0i32; 32];
    // Pair even[i] with odd[15 - i] by walking the odd half backwards.
    for (k, (&e, &o)) in even_out.iter().zip(odd_out.iter().rev()).enumerate() {
        result[k] = e + o;
        result[31 - k] = e - o;
    }
    result
}
/// Scalar 16-point inverse DCT.
///
/// Even-indexed inputs go through `scalar_dct8_1d`; the odd-indexed inputs
/// are reduced by the butterfly/rotation network below to eight terms that
/// are added to (outputs 0..8) and subtracted from (outputs 15..8, mirrored)
/// the even half. All rotations use 12-bit fixed-point coefficients with
/// round-to-nearest (`+ 2048 >> 12`).
#[allow(dead_code)]
fn scalar_dct16_1d(input: &[i32; 16]) -> [i32; 16] {
    /// `(a * ca - b * cb)` rounded and scaled down by 12 bits.
    fn rot_sub(a: i32, ca: i32, b: i32, cb: i32) -> i32 {
        (a * ca - b * cb + 2048) >> 12
    }
    /// `(a * ca + b * cb)` rounded and scaled down by 12 bits.
    fn rot_add(a: i32, ca: i32, b: i32, cb: i32) -> i32 {
        (a * ca + b * cb + 2048) >> 12
    }

    let even_in: [i32; 8] = core::array::from_fn(|k| input[2 * k]);
    let even_out = scalar_dct8_1d(&even_in);

    // Stage 1: rotate the odd inputs pairwise.
    let r8 = rot_sub(input[1], 401, input[15], 4076);
    let r15 = rot_add(input[1], 4076, input[15], 401);
    let r9 = rot_sub(input[9], 3166, input[7], 2598);
    let r14 = rot_add(input[9], 2598, input[7], 3166);
    let r10 = rot_sub(input[5], 1931, input[11], 3612);
    let r13 = rot_add(input[5], 3612, input[11], 1931);
    let r11 = rot_sub(input[13], 3920, input[3], 1189);
    let r12 = rot_add(input[13], 1189, input[3], 3920);

    // Stage 2: add/sub butterflies on neighbours.
    let (a8, a9) = (r8 + r9, r8 - r9);
    let (a11, a10) = (r11 + r10, r11 - r10);
    let (a12, a13) = (r12 + r13, r12 - r13);
    let (a15, a14) = (r15 + r14, r15 - r14);

    // Stage 3: rotate the inner pairs; the 10-term is the negated sum form.
    let b9 = rot_sub(a14, 1567, a9, 3784);
    let b14 = rot_add(a14, 3784, a9, 1567);
    let b13 = rot_sub(a13, 1567, a10, 3784);
    let b10 = -rot_add(a13, 3784, a10, 1567);

    // Stage 4: add/sub butterflies across groups of four.
    let (c8, c11) = (a8 + a11, a8 - a11);
    let (c9, c10) = (b9 + b10, b9 - b10);
    let (c15, c12) = (a15 + a12, a15 - a12);
    let (c14, c13) = (b14 + b13, b14 - b13);

    // Stage 5: final rotation of the middle four terms by 2896 / 4096.
    let d11 = rot_sub(c12, 2896, c11, 2896);
    let d12 = rot_add(c12, 2896, c11, 2896);
    let d10 = rot_sub(c13, 2896, c10, 2896);
    let d13 = rot_add(c13, 2896, c10, 2896);

    // Odd-half term paired with even output k, for k = 0..8.
    let odd_terms = [c15, c14, d13, d12, d11, d10, c9, c8];

    let mut result = [0i32; 16];
    for (k, (&e, &o)) in even_out.iter().zip(odd_terms.iter()).enumerate() {
        result[k] = e + o;
        result[15 - k] = e - o;
    }
    result
}
/// Scalar 8-point inverse DCT.
///
/// Inputs 0, 2, 4, 6 go through `scalar_dct4_1d`; inputs 1, 7, 5, 3 are
/// rotated and butterflied into four terms that are added to (outputs 0..4)
/// and subtracted from (outputs 7..4, mirrored) the even half.
#[allow(dead_code)]
fn scalar_dct8_1d(input: &[i32; 8]) -> [i32; 8] {
    /// `(a * ca - b * cb)` rounded and scaled down by 12 bits.
    fn rot_sub(a: i32, ca: i32, b: i32, cb: i32) -> i32 {
        (a * ca - b * cb + 2048) >> 12
    }
    /// `(a * ca + b * cb)` rounded and scaled down by 12 bits.
    fn rot_add(a: i32, ca: i32, b: i32, cb: i32) -> i32 {
        (a * ca + b * cb + 2048) >> 12
    }
    /// `x * 2896 / 4096`, rounded to nearest.
    fn scale_2896(x: i32) -> i32 {
        (x * 2896 + 2048) >> 12
    }

    let even_out = scalar_dct4_1d(input[0], input[2], input[4], input[6]);

    let r4 = rot_sub(input[1], 799, input[7], 4017);
    let r7 = rot_add(input[1], 4017, input[7], 799);
    let r5 = rot_sub(input[5], 3406, input[3], 2276);
    let r6 = rot_add(input[5], 2276, input[3], 3406);

    let (s4, s5) = (r4 + r5, r4 - r5);
    let (s7, s6) = (r7 + r6, r7 - r6);

    // The middle pair is formed from the difference and sum before scaling.
    let m5 = scale_2896(s6 - s5);
    let m6 = scale_2896(s6 + s5);

    // Odd-half term paired with even output k, for k = 0..4.
    let odd_terms = [s7, m6, m5, s4];

    let mut result = [0i32; 8];
    for (k, (&e, &o)) in even_out.iter().zip(odd_terms.iter()).enumerate() {
        result[k] = e + o;
        result[7 - k] = e - o;
    }
    result
}
/// Scalar 4-point inverse DCT.
///
/// `in0`/`in2` form the sum and difference terms scaled by 2896 / 4096;
/// `in1`/`in3` are rotated with the coefficient pair (1567, 3784). All
/// products are rounded to nearest with `+ 2048 >> 12`.
#[allow(dead_code)]
fn scalar_dct4_1d(in0: i32, in1: i32, in2: i32, in3: i32) -> [i32; 4] {
    let sum = in0 + in2;
    let diff = in0 - in2;

    let even_a = (sum * 2896 + 2048) >> 12;
    let even_b = (diff * 2896 + 2048) >> 12;
    let odd_b = (in1 * 1567 - in3 * 3784 + 2048) >> 12;
    let odd_a = (in1 * 3784 + in3 * 1567 + 2048) >> 12;

    [
        even_a + odd_a,
        even_b + odd_b,
        even_b - odd_b,
        even_a - odd_a,
    ]
}
/// Scalar odd-half network of the 32-point inverse DCT.
///
/// `v` holds the sixteen odd-indexed inputs of the 32-point transform in
/// ascending order. The result holds the sixteen odd-half terms in the order
/// `t16 .. t31`. Rotations use 12-bit fixed-point coefficients with
/// round-to-nearest (`+ 2048 >> 12`); butterflies are plain `i32` add/sub.
#[allow(dead_code)]
fn scalar_idct32_odd(v: &[i32; 16]) -> [i32; 16] {
    /// `(a * ca - b * cb)` rounded and scaled down by 12 bits.
    fn rot_sub(a: i32, ca: i32, b: i32, cb: i32) -> i32 {
        (a * ca - b * cb + 2048) >> 12
    }
    /// `(a * ca + b * cb)` rounded and scaled down by 12 bits.
    fn rot_add(a: i32, ca: i32, b: i32, cb: i32) -> i32 {
        (a * ca + b * cb + 2048) >> 12
    }

    // Stage 1: eight input rotations, one coefficient pair each.
    let t16a = rot_sub(v[0], 201, v[15], 4091);
    let t31a = rot_add(v[0], 4091, v[15], 201);
    let t17a = rot_sub(v[8], 3035, v[7], 2751);
    let t30a = rot_add(v[8], 2751, v[7], 3035);
    let t18a = rot_sub(v[4], 1751, v[11], 3703);
    let t29a = rot_add(v[4], 3703, v[11], 1751);
    let t19a = rot_sub(v[12], 3857, v[3], 1380);
    let t28a = rot_add(v[12], 1380, v[3], 3857);
    let t20a = rot_sub(v[2], 995, v[13], 3973);
    let t27a = rot_add(v[2], 3973, v[13], 995);
    let t21a = rot_sub(v[10], 3513, v[5], 2106);
    let t26a = rot_add(v[10], 2106, v[5], 3513);
    let t22a = rot_sub(v[6], 2440, v[9], 3290);
    let t25a = rot_add(v[6], 3290, v[9], 2440);
    let t23a = rot_sub(v[14], 4052, v[1], 601);
    let t24a = rot_add(v[14], 601, v[1], 4052);

    // Stage 2: add/sub butterflies on neighbouring pairs.
    let (s16, s17) = (t16a + t17a, t16a - t17a);
    let (s31, s30) = (t31a + t30a, t31a - t30a);
    let (s19, s18) = (t19a + t18a, t19a - t18a);
    let (s20, s21) = (t20a + t21a, t20a - t21a);
    let (s23, s22) = (t23a + t22a, t23a - t22a);
    let (s24, s25) = (t24a + t25a, t24a - t25a);
    let (s27, s26) = (t27a + t26a, t27a - t26a);
    let (s28, s29) = (t28a + t29a, t28a - t29a);

    // Stage 3: rotations by (799, 4017) and (3406, 2276); the 18- and
    // 22-terms are the negated sum form.
    let u17a = rot_sub(s30, 799, s17, 4017);
    let u30a = rot_add(s30, 4017, s17, 799);
    let u18a = -rot_add(s29, 4017, s18, 799);
    let u29a = rot_sub(s29, 799, s18, 4017);
    let u21a = rot_sub(s26, 3406, s21, 2276);
    let u26a = rot_add(s26, 2276, s21, 3406);
    let u22a = -rot_add(s25, 2276, s22, 3406);
    let u25a = rot_sub(s25, 3406, s22, 2276);

    // Stage 4: add/sub butterflies across groups of four.
    let (w30, w29) = (u30a + u29a, u30a - u29a);
    let (w17, w18) = (u17a + u18a, u17a - u18a);
    let (w16a, w19a) = (s16 + s19, s16 - s19);
    let (w23a, w20a) = (s23 + s20, s23 - s20);
    let (w22, w21) = (u22a + u21a, u22a - u21a);
    let (w24a, w27a) = (s24 + s27, s24 - s27);
    let (w25, w26) = (u25a + u26a, u25a - u26a);
    let (w31a, w28a) = (s31 + s28, s31 - s28);

    // Stage 5: rotations by (1567, 3784); the 20- and 21-terms are negated sums.
    let x18a = rot_sub(w29, 1567, w18, 3784);
    let x29a = rot_add(w29, 3784, w18, 1567);
    let x19 = rot_sub(w28a, 1567, w19a, 3784);
    let x28 = rot_add(w28a, 3784, w19a, 1567);
    let x20 = -rot_add(w27a, 3784, w20a, 1567);
    let x27 = rot_sub(w27a, 1567, w20a, 3784);
    let x21a = -rot_add(w26, 3784, w21, 1567);
    let x26a = rot_sub(w26, 1567, w21, 3784);

    // Stage 6: add/sub butterflies across groups of eight.
    let (y16, y23) = (w16a + w23a, w16a - w23a);
    let (y31, y24) = (w31a + w24a, w31a - w24a);
    let (y17a, y22a) = (w17 + w22, w17 - w22);
    let (y30a, y25a) = (w30 + w25, w30 - w25);
    let (y18, y21) = (x18a + x21a, x18a - x21a);
    let (y19a, y20a) = (x19 + x20, x19 - x20);
    let (y29, y26) = (x29a + x26a, x29a - x26a);
    let (y28a, y27a) = (x28 + x27, x28 - x27);

    // Stage 7: the middle eight terms are rotated by 2896 / 4096.
    let z20 = rot_sub(y27a, 2896, y20a, 2896);
    let z27 = rot_add(y27a, 2896, y20a, 2896);
    let z21a = rot_sub(y26, 2896, y21, 2896);
    let z26a = rot_add(y26, 2896, y21, 2896);
    let z22 = rot_sub(y25a, 2896, y22a, 2896);
    let z25 = rot_add(y25a, 2896, y22a, 2896);
    let z23a = rot_sub(y24, 2896, y23, 2896);
    let z24a = rot_add(y24, 2896, y23, 2896);

    [
        y16, y17a, y18, y19a, z20, z21a, z22, z23a, z24a, z25, z26a, z27, y28a, y29, y30a, y31,
    ]
}
// Unit tests for the scalar reference transforms. The NEON paths are not
// exercised here.
#[cfg(test)]
mod tests {
use super::*;
// A lone DC input to the 4-point transform must give (near-)equal outputs.
#[test]
fn test_scalar_dct4() {
let out = scalar_dct4_1d(4096, 0, 0, 0);
assert!(
out.iter().all(|&x| (x - out[0]).abs() <= 1),
"DC-only DCT4 should produce equal outputs, got {:?}",
out
);
}
// All-zero input to the odd-half network must give all-zero output.
#[test]
fn test_scalar_idct32_odd_zero() {
let input = [0i32; 16];
let out = scalar_idct32_odd(&input);
assert_eq!(out, [0i32; 16]);
}
// A lone DC input to the 32-point transform must give a flat output,
// allowing a deviation of 2 from the mean for rounding.
#[test]
fn test_scalar_dct32_dc_only() {
let mut input = [0i32; 32];
input[0] = 4096;
let out = scalar_dct32_1d(&input);
let mean = out.iter().sum::<i32>() / 32;
for (i, &v) in out.iter().enumerate() {
assert!(
(v - mean).abs() <= 2,
"DCT32 DC-only: output[{}]={} deviates from mean={}",
i,
v,
mean
);
}
}
// NOTE(review): despite the name, this does not perform a round trip (no
// forward transform is involved). It only checks that a ramp input yields
// outputs below 100000 in magnitude.
#[test]
fn test_scalar_dct32_roundtrip() {
let mut input = [0i32; 32];
for i in 0..32 {
input[i] = (i as i32) * 100 - 1600;
}
let transformed = scalar_dct32_1d(&input);
assert!(
transformed.iter().all(|&x| x.abs() < 100000),
"Transform output should be bounded"
);
}
}