#![allow(clippy::too_many_arguments)]
#![cfg_attr(not(feature = "unchecked"), forbid(unsafe_code))]
#![cfg_attr(feature = "unchecked", deny(unsafe_code))]
#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;
#[cfg(target_arch = "aarch64")]
use archmage::{Arm64, arcane, rite};
#[cfg(target_arch = "aarch64")]
use safe_unaligned_simd::aarch64 as safe_simd;
use super::itx_arm_neon_8x8::{
iadst_8_q, idct_8_q, smull_smlal_q, smull_smlsl_q, sqrshrn_pair, transpose_8x8h,
};
use super::itx_arm_neon_common::IDCT_COEFFS;
/// Working set of the 16-point 1-D transforms in this file: sixteen
/// `int16x8_t` vectors, passed and returned by value.
#[cfg(target_arch = "aarch64")]
type V16 = [int16x8_t; 16];
/// Multipliers for the first half of the opening stage of `iadst_16_q`.
///
/// Adjacent lane pairs (0,1), (2,3), (4,5), (6,7) are applied to the input
/// pairs 15/0, 13/2, 11/4 and 9/6 respectively.
// NOTE(review): the values look like 12-bit fixed-point trig constants — confirm.
#[cfg(target_arch = "aarch64")]
pub(crate) const IADST16_COEFFS_V0: [i16; 8] = [4091, 201, 3973, 995, 3703, 1751, 3290, 2440];
/// Multipliers for the second half of the opening stage of `iadst_16_q`.
///
/// Adjacent lane pairs (0,1), (2,3), (4,5), (6,7) are applied to the input
/// pairs 7/8, 5/10, 3/12 and 1/14 respectively.
#[cfg(target_arch = "aarch64")]
pub(crate) const IADST16_COEFFS_V1: [i16; 8] = [2751, 3035, 2106, 3513, 1380, 3857, 601, 4052];
/// 16-point inverse DCT over sixteen `int16x8_t` vectors.
///
/// The even-indexed inputs are handed to `idct_8_q`. The odd-indexed inputs go
/// through the multiply/round stages below, interleaved with saturating
/// butterflies. The two halves are then merged: outputs 0..=7 are saturating
/// sums, outputs 8..=15 are the mirrored saturating differences.
///
/// # Panics
/// Panics if `IDCT_COEFFS` holds fewer than 16 entries (slice indexing).
// NOTE(review): `smull_smlsl_q` / `smull_smlal_q` presumably compute a widening
// `a * c[i] - b * c[j]` / `a * c[i] + b * c[j]`, and `sqrshrn_pair` presumably
// rounds and narrows the two 32-bit halves back to i16 — confirm against
// itx_arm_neon_8x8.
#[cfg(target_arch = "aarch64")]
#[rite(neon)]
pub(crate) fn idct_16_q(v: V16) -> V16 {
// Even half: an 8-point transform of inputs 0, 2, ..., 14.
let (e0, e1, e2, e3, e4, e5, e6, e7) =
idct_8_q(v[0], v[2], v[4], v[6], v[8], v[10], v[12], v[14]);
// Coefficients 8..16 drive the first odd stage, one adjacent lane pair per input pair.
let c1 = safe_simd::vld1q_s16(<&[i16; 8]>::try_from(&IDCT_COEFFS[8..16]).unwrap());
// Input pair 1/15, lanes 0/1.
let (lo, hi) = smull_smlsl_q(v[1], v[15], c1, 0, 1);
let t8a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(v[1], v[15], c1, 1, 0);
let t15a = sqrshrn_pair(lo, hi);
// Input pair 9/7, lanes 2/3.
let (lo, hi) = smull_smlsl_q(v[9], v[7], c1, 2, 3);
let t9a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(v[9], v[7], c1, 3, 2);
let t14a = sqrshrn_pair(lo, hi);
// Input pair 5/11, lanes 4/5.
let (lo, hi) = smull_smlsl_q(v[5], v[11], c1, 4, 5);
let t10a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(v[5], v[11], c1, 5, 4);
let t13a = sqrshrn_pair(lo, hi);
// Input pair 13/3, lanes 6/7.
let (lo, hi) = smull_smlsl_q(v[13], v[3], c1, 6, 7);
let t11a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(v[13], v[3], c1, 7, 6);
let t12a = sqrshrn_pair(lo, hi);
// Saturating butterflies between neighbouring first-stage results.
let t9 = vqsubq_s16(t8a, t9a); let t8 = vqaddq_s16(t8a, t9a); let t14 = vqsubq_s16(t15a, t14a); let t15 = vqaddq_s16(t15a, t14a); let t10 = vqsubq_s16(t11a, t10a); let t11 = vqaddq_s16(t11a, t10a); let t12 = vqaddq_s16(t12a, t13a); let t13 = vqsubq_s16(t12a, t13a);
// Coefficients 0..8: lanes 2/3 drive this stage, lane 0 drives the final stage.
let c0 = safe_simd::vld1q_s16(<&[i16; 8]>::try_from(&IDCT_COEFFS[0..8]).unwrap());
let (lo, hi) = smull_smlsl_q(t14, t9, c0, 2, 3);
let t9a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(t14, t9, c0, 3, 2);
let t14a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(t13, t10, c0, 2, 3);
let t13a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(t13, t10, c0, 3, 2);
// This product is negated in 32 bits before it is rounded and narrowed.
let t10a = sqrshrn_pair(vnegq_s32(lo), vnegq_s32(hi));
// Second round of saturating butterflies.
let t11a = vqsubq_s16(t8, t11); let t8a = vqaddq_s16(t8, t11); let t12a = vqsubq_s16(t15, t12); let t15a = vqaddq_s16(t15, t12); let t9b = vqaddq_s16(t9a, t10a); let t10b = vqsubq_s16(t9a, t10a); let t13b = vqsubq_s16(t14a, t13a); let t14b = vqaddq_s16(t14a, t13a);
// Final stage: both multipliers come from lane 0 of `c0`.
let (lo, hi) = smull_smlsl_q(t12a, t11a, c0, 0, 0);
let t11_final = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(t12a, t11a, c0, 0, 0);
let t12_final = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(t13b, t10b, c0, 0, 0);
let t10a_final = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(t13b, t10b, c0, 0, 0);
let t13a_final = sqrshrn_pair(lo, hi);
// Merge: out[k] = e_k + t_(15-k) for k in 0..8, out[15-k] = e_k - t_(15-k).
[
vqaddq_s16(e0, t15a), vqaddq_s16(e1, t14b), vqaddq_s16(e2, t13a_final), vqaddq_s16(e3, t12_final), vqaddq_s16(e4, t11_final), vqaddq_s16(e5, t10a_final), vqaddq_s16(e6, t9b), vqaddq_s16(e7, t8a), vqsubq_s16(e7, t8a), vqsubq_s16(e6, t9b), vqsubq_s16(e5, t10a_final), vqsubq_s16(e4, t11_final), vqsubq_s16(e3, t12_final), vqsubq_s16(e2, t13a_final), vqsubq_s16(e1, t14b), vqsubq_s16(e0, t15a), ]
}
/// 16-point inverse ADST over sixteen `int16x8_t` vectors.
///
/// Structure, as written below:
/// 1. eight multiply/round pairs on the input pairs (15,0), (13,2), ..., (1,14)
///    using `IADST16_COEFFS_V0` and `IADST16_COEFFS_V1`;
/// 2. saturating butterflies, then multiply/round on the difference terms using
///    lanes 4..=7 of `IDCT_COEFFS[0..8]`;
/// 3. saturating butterflies, then multiply/round using lanes 2/3;
/// 4. final butterflies and lane-0 multiplies; half of the outputs are negated
///    with `vqnegq_s16`.
///
/// # Panics
/// Panics if `IDCT_COEFFS` holds fewer than 8 entries (slice indexing).
// NOTE(review): `smull_smlal_q` / `smull_smlsl_q` presumably compute a widening
// `a * c[i] + b * c[j]` / `a * c[i] - b * c[j]`, and `sqrshrn_pair` presumably
// rounds and narrows back to i16 — confirm against itx_arm_neon_8x8.
#[cfg(target_arch = "aarch64")]
#[rite(neon)]
pub(crate) fn iadst_16_q(v: V16) -> V16 {
let c0 = safe_simd::vld1q_s16(<&[i16; 8]>::try_from(&IADST16_COEFFS_V0[..]).unwrap());
let c1 = safe_simd::vld1q_s16(<&[i16; 8]>::try_from(&IADST16_COEFFS_V1[..]).unwrap());
// Stage 1, first half: input pairs 15/0, 13/2, 11/4, 9/6 with `c0`.
let (lo, hi) = smull_smlal_q(v[15], v[0], c0, 0, 1);
let t0 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(v[15], v[0], c0, 1, 0);
let t1 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(v[13], v[2], c0, 2, 3);
let t2 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(v[13], v[2], c0, 3, 2);
let t3 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(v[11], v[4], c0, 4, 5);
let t4 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(v[11], v[4], c0, 5, 4);
let t5 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(v[9], v[6], c0, 6, 7);
let t6 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(v[9], v[6], c0, 7, 6);
let t7 = sqrshrn_pair(lo, hi);
// Stage 1, second half: input pairs 7/8, 5/10, 3/12, 1/14 with `c1`.
let (lo, hi) = smull_smlal_q(v[7], v[8], c1, 0, 1);
let t8 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(v[7], v[8], c1, 1, 0);
let t9 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(v[5], v[10], c1, 2, 3);
let t10 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(v[5], v[10], c1, 3, 2);
let t11 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(v[3], v[12], c1, 4, 5);
let t12 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(v[3], v[12], c1, 5, 4);
let t13 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(v[1], v[14], c1, 6, 7);
let t14 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(v[1], v[14], c1, 7, 6);
let t15 = sqrshrn_pair(lo, hi);
// The remaining stages all take their multipliers from IDCT_COEFFS[0..8].
let ci = safe_simd::vld1q_s16(<&[i16; 8]>::try_from(&IDCT_COEFFS[0..8]).unwrap());
// Stage 2 butterflies: t_k combined with t_(k+8), sums in s0a..s7a, differences in s8a..s15a.
let s8a = vqsubq_s16(t0, t8); let s0a = vqaddq_s16(t0, t8); let s9a = vqsubq_s16(t1, t9); let s1a = vqaddq_s16(t1, t9); let s2a = vqaddq_s16(t2, t10); let s10a = vqsubq_s16(t2, t10); let s3a = vqaddq_s16(t3, t11); let s11a = vqsubq_s16(t3, t11); let s4a = vqaddq_s16(t4, t12); let s12a = vqsubq_s16(t4, t12); let s5a = vqaddq_s16(t5, t13); let s13a = vqsubq_s16(t5, t13); let s6a = vqaddq_s16(t6, t14); let s14a = vqsubq_s16(t6, t14); let s7a = vqaddq_s16(t7, t15); let s15a = vqsubq_s16(t7, t15);
// Stage 2 multiplies on the difference terms, lanes 4..=7 of `ci`.
let (lo, hi) = smull_smlal_q(s8a, s9a, ci, 5, 4);
let u8_ = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(s8a, s9a, ci, 4, 5);
let u9 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(s10a, s11a, ci, 7, 6);
let u10 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(s10a, s11a, ci, 6, 7);
let u11 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(s13a, s12a, ci, 5, 4);
let u12 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(s13a, s12a, ci, 4, 5);
let u13 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(s15a, s14a, ci, 7, 6);
let u14 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(s15a, s14a, ci, 6, 7);
let u15 = sqrshrn_pair(lo, hi);
// Stage 3 butterflies: index k combined with index k+4 in both groups.
let w4 = vqsubq_s16(s0a, s4a); let w0 = vqaddq_s16(s0a, s4a); let w5 = vqsubq_s16(s1a, s5a); let w1 = vqaddq_s16(s1a, s5a); let w2 = vqaddq_s16(s2a, s6a); let w6 = vqsubq_s16(s2a, s6a); let w3 = vqaddq_s16(s3a, s7a); let w7 = vqsubq_s16(s3a, s7a); let w8a = vqaddq_s16(u8_, u12); let w12a = vqsubq_s16(u8_, u12); let w9a = vqaddq_s16(u9, u13); let w13a = vqsubq_s16(u9, u13); let w10a = vqaddq_s16(u10, u14); let w14a = vqsubq_s16(u10, u14); let w11a = vqaddq_s16(u11, u15); let w15a = vqsubq_s16(u11, u15);
// Stage 3 multiplies, lanes 2/3 of `ci`.
let (lo, hi) = smull_smlal_q(w4, w5, ci, 3, 2);
let x4a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(w4, w5, ci, 2, 3);
let x5a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(w7, w6, ci, 3, 2);
let x6a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(w7, w6, ci, 2, 3);
let x7a = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(w12a, w13a, ci, 3, 2);
let x12 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(w12a, w13a, ci, 2, 3);
let x13 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(w15a, w14a, ci, 3, 2);
let x14 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(w15a, w14a, ci, 2, 3);
let x15 = sqrshrn_pair(lo, hi);
// Stage 4 butterflies: sums become outputs directly (some negated),
// differences feed the lane-0 multiplies below.
let t2a = vqsubq_s16(w0, w2);
let o0 = vqaddq_s16(w0, w2);
let t3a = vqsubq_s16(w1, w3);
let o15 = vqnegq_s16(vqaddq_s16(w1, w3));
let t15a = vqsubq_s16(x13, x15);
let o13 = vqnegq_s16(vqaddq_s16(x13, x15));
let o2 = vqaddq_s16(x12, x14);
let t14a = vqsubq_s16(x12, x14);
let o1 = vqnegq_s16(vqaddq_s16(w8a, w10a));
let y10 = vqsubq_s16(w8a, w10a);
let o14 = vqaddq_s16(w9a, w11a);
let y11 = vqsubq_s16(w9a, w11a);
let o3 = vqnegq_s16(vqaddq_s16(x4a, x6a));
let y6 = vqsubq_s16(x4a, x6a);
let o12 = vqaddq_s16(x5a, x7a);
let y7 = vqsubq_s16(x5a, x7a);
// Final multiplies: both multipliers come from lane 0 of `ci`.
let (lo, hi) = smull_smlsl_q(t2a, t3a, ci, 0, 0);
let o8 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(t2a, t3a, ci, 0, 0);
let o7_pre = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(t14a, t15a, ci, 0, 0);
let o5_pre = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(t14a, t15a, ci, 0, 0);
let o10 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(y6, y7, ci, 0, 0);
let o4 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(y6, y7, ci, 0, 0);
let o11_pre = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlal_q(y10, y11, ci, 0, 0);
let o6 = sqrshrn_pair(lo, hi);
let (lo, hi) = smull_smlsl_q(y10, y11, ci, 0, 0);
let o9_pre = sqrshrn_pair(lo, hi);
// Outputs 5, 7, 9 and 11 are the saturating negation of their rounded products.
let o7 = vqnegq_s16(o7_pre);
let o5 = vqnegq_s16(o5_pre);
let o11 = vqnegq_s16(o11_pre);
let o9 = vqnegq_s16(o9_pre);
[
o0, o1, o2, o3, o4, o5, o6, o7, o8, o9, o10, o11, o12, o13, o14, o15,
]
}
/// 16-point identity transform: each vector `x` becomes
/// `sat(sat(x + x) + vqrdmulh(x, (5793 - 4096) * 2 * 8))`.
///
/// The fractional term is computed from the original `x`, not the doubled value.
#[cfg(target_arch = "aarch64")]
#[rite(neon)]
pub(crate) fn identity_16_q(mut v: V16) -> V16 {
    const FRAC_MUL: i16 = ((5793 - 4096) * 2 * 8) as i16;
    let frac_mul = vdupq_n_s16(FRAC_MUL);
    for lane in &mut v {
        let x = *lane;
        let frac = vqrdmulhq_s16(x, frac_mul);
        let doubled = vqaddq_s16(x, x);
        *lane = vqaddq_s16(doubled, frac);
    }
    v
}
/// Transposes two groups of eight vectors independently with `transpose_8x8h`
/// and returns the two results in the same order (`a` first, then `b`).
#[cfg(target_arch = "aarch64")]
#[rite(neon)]
fn transpose_16x16_half(a: [int16x8_t; 8], b: [int16x8_t; 8]) -> ([int16x8_t; 8], [int16x8_t; 8]) {
    let [a0, a1, a2, a3, a4, a5, a6, a7] = a;
    let [b0, b1, b2, b3, b4, b5, b6, b7] = b;
    let (p0, p1, p2, p3, p4, p5, p6, p7) = transpose_8x8h(a0, a1, a2, a3, a4, a5, a6, a7);
    let (q0, q1, q2, q3, q4, q5, q6, q7) = transpose_8x8h(b0, b1, b2, b3, b4, b5, b6, b7);
    let first = [p0, p1, p2, p3, p4, p5, p6, p7];
    let second = [q0, q1, q2, q3, q4, q5, q6, q7];
    (first, second)
}
/// Adds an 8-wide, 16-row block of residuals to 8-bit pixels.
///
/// Row `i` starts at `dst_base` offset by `i * stride` elements. Each residual
/// vector is rounded-shifted right by 4, added to the widened pixels and
/// narrowed back with unsigned saturation.
///
/// # Panics
/// Panics if any row's 8-pixel span lies outside `dst`.
#[cfg(target_arch = "aarch64")]
#[rite(neon)]
fn add_to_dst_8x16_8bpc(dst: &mut [u8], dst_base: usize, stride: isize, v: V16) {
    for (row_idx, residual) in v.iter().copied().enumerate() {
        let start = dst_base.wrapping_add_signed(row_idx as isize * stride);
        let rounded = vrshrq_n_s16::<4>(residual);
        let pixels = &mut dst[start..start + 8];
        let mut buf = [0u8; 8];
        buf.copy_from_slice(pixels);
        let widened = vaddw_u8(vreinterpretq_u16_s16(rounded), safe_simd::vld1_u8(&buf));
        let packed = vqmovun_s16(vreinterpretq_s16_u16(widened));
        safe_simd::vst1_u8(&mut buf, packed);
        pixels.copy_from_slice(&buf);
    }
}
/// Adds an 8-wide, 16-row block of residuals to 16-bit pixels.
///
/// Row `i` starts at `dst_base` offset by `i * stride` elements. Each residual
/// vector is rounded-shifted right by 4, added with signed saturation to the
/// pixels (reinterpreted as `i16`) and clamped to `0..=bitdepth_max`.
///
/// # Panics
/// Panics if any row's 8-pixel span lies outside `dst`.
#[cfg(target_arch = "aarch64")]
#[rite(neon)]
fn add_to_dst_8x16_16bpc(
    dst: &mut [u16],
    dst_base: usize,
    stride: isize,
    v: V16,
    bitdepth_max: i32,
) {
    let ceiling = vdupq_n_s16(bitdepth_max as i16);
    let floor = vdupq_n_s16(0);
    for (row_idx, residual) in v.iter().copied().enumerate() {
        let start = dst_base.wrapping_add_signed(row_idx as isize * stride);
        let rounded = vrshrq_n_s16::<4>(residual);
        let pixels = &mut dst[start..start + 8];

        // Stage the pixels in a signed scratch array for the vector load.
        let mut lanes = [0i16; 8];
        for (lane, &px) in lanes.iter_mut().zip(pixels.iter()) {
            *lane = px as i16;
        }

        let summed = vqaddq_s16(safe_simd::vld1q_s16(&lanes), rounded);
        let bounded = vminq_s16(vmaxq_s16(summed, floor), ceiling);
        safe_simd::vst1q_s16(&mut lanes, bounded);

        for (px, &lane) in pixels.iter_mut().zip(lanes.iter()) {
            *px = lane as u16;
        }
    }
}
/// Selects which 16-point 1-D transform `apply_tx16` runs.
#[cfg(target_arch = "aarch64")]
#[derive(Clone, Copy)]
pub(crate) enum TxType16 {
/// Dispatches to `idct_16_q`.
Dct,
/// Dispatches to `iadst_16_q`.
Adst,
/// Dispatches to `iadst_16_q` and reverses the order of the 16 outputs.
FlipAdst,
/// Dispatches to `identity_16_q`.
Identity,
}
/// Runs the 16-point 1-D transform selected by `tx` on `v`.
///
/// `FlipAdst` is the ADST with its sixteen output vectors in reverse order.
#[cfg(target_arch = "aarch64")]
#[rite(neon)]
fn apply_tx16(tx: TxType16, v: V16) -> V16 {
    let flipped = matches!(tx, TxType16::FlipAdst);
    let mut result = match tx {
        TxType16::Dct => idct_16_q(v),
        TxType16::Adst | TxType16::FlipAdst => iadst_16_q(v),
        TxType16::Identity => identity_16_q(v),
    };
    if flipped {
        result.reverse();
    }
    result
}
/// DC-only 16x16 path for 8-bit pixels.
///
/// Reads and clears `coeff[0]`, scales it (`vqrdmulh` by `2896 * 8`, rounding
/// shift by 2, `vqrdmulh` again, rounding shift by 4) and adds the resulting
/// constant to every pixel of the 16x16 block with unsigned saturation.
///
/// # Panics
/// Panics if `coeff` is empty or any row's 16-pixel span lies outside `dst`.
#[cfg(target_arch = "aarch64")]
#[rite(neon)]
fn dc_only_16x16_8bpc(dst: &mut [u8], dst_base: usize, dst_stride: isize, coeff: &mut [i16]) {
    let dc = core::mem::replace(&mut coeff[0], 0);
    let gain = vdupq_n_s16((2896 * 8) as i16);

    let mut delta = vqrdmulhq_s16(vdupq_n_s16(dc), gain);
    delta = vrshrq_n_s16::<2>(delta);
    delta = vqrdmulhq_s16(delta, gain);
    delta = vrshrq_n_s16::<4>(delta);

    for row in 0..16 {
        let row_start = dst_base.wrapping_add_signed(row as isize * dst_stride);
        // Left eight pixels first, then the right eight.
        for half in 0..2usize {
            let start = row_start + half * 8;
            let pixels = &mut dst[start..start + 8];
            let mut buf = [0u8; 8];
            buf.copy_from_slice(pixels);
            let widened = vaddw_u8(vreinterpretq_u16_s16(delta), safe_simd::vld1_u8(&buf));
            let packed = vqmovun_s16(vreinterpretq_s16_u16(widened));
            safe_simd::vst1_u8(&mut buf, packed);
            pixels.copy_from_slice(&buf);
        }
    }
}
/// DC-only 16x16 path for 16-bit pixels.
///
/// Reads and clears `coeff[0]`, scales it in scalar arithmetic (Q15 multiply by
/// `2896 * 8`, rounding shift by 2, Q15 multiply again, rounding shift by 4),
/// saturates the result to `i16`, and adds it to every pixel of the 16x16
/// block, clamping to `0..=bitdepth_max`.
///
/// # Panics
/// Panics if `coeff` is empty or any row's 16-pixel span lies outside `dst`.
#[cfg(target_arch = "aarch64")]
#[rite(neon)]
fn dc_only_16x16_16bpc(
    dst: &mut [u16],
    dst_base: usize,
    dst_stride: isize,
    coeff: &mut [i32],
    bitdepth_max: i32,
) {
    /// Rounded Q15 multiply by `2896 * 8`, evaluated in 64 bits.
    fn scale_q15(x: i32) -> i32 {
        let gain = 2896i32 * 8;
        ((x as i64 * gain as i64 + 16384) >> 15) as i32
    }

    let raw = core::mem::replace(&mut coeff[0], 0);
    let after_rows = (scale_q15(raw) + 2) >> 2;
    let after_cols = (scale_q15(after_rows) + 8) >> 4;
    let delta = vdupq_n_s16(after_cols.clamp(i16::MIN as i32, i16::MAX as i32) as i16);

    let ceiling = vdupq_n_s16(bitdepth_max as i16);
    let floor = vdupq_n_s16(0);

    for row in 0..16 {
        let row_start = dst_base.wrapping_add_signed(row as isize * dst_stride);
        for half in 0..2usize {
            let start = row_start + half * 8;
            let pixels = &mut dst[start..start + 8];

            let mut lanes = [0i16; 8];
            for (lane, &px) in lanes.iter_mut().zip(pixels.iter()) {
                *lane = px as i16;
            }

            let summed = vqaddq_s16(safe_simd::vld1q_s16(&lanes), delta);
            let bounded = vminq_s16(vmaxq_s16(summed, floor), ceiling);
            safe_simd::vst1q_s16(&mut lanes, bounded);

            for (px, &lane) in pixels.iter_mut().zip(lanes.iter()) {
                *px = lane as u16;
            }
        }
    }
}
/// 16x16 inverse transform + add for 8-bit pixels.
///
/// * `dst` / `dst_base` / `dst_stride` — pixel buffer, index of the block's
///   first pixel, and distance between rows in elements (may be negative).
/// * `coeff` — 256 coefficients, read as `coeff[col * 16 + row]`; every
///   coefficient that is read is zeroed afterwards.
/// * `eob` — with DCT/DCT and `eob == 0` only the DC shortcut runs; otherwise
///   rows 8..16 of the first pass are skipped when `eob < 36`.
/// * `row_tx` / `col_tx` — transforms for the first and second pass.
///
/// The first pass transforms eight rows at a time, rounds (`>> 2`, or the
/// dedicated identity scaling) and transposes into a 16x16 scratch buffer. The
/// second pass transforms eight columns at a time and adds them to `dst`.
///
/// # Panics
/// Panics if `coeff` or `dst` is too short for the accesses described above.
// NOTE(review): the second-half threshold is 36 for every transform pair;
// confirm it matches the scan order used for identity combinations.
#[cfg(target_arch = "aarch64")]
#[arcane]
pub(crate) fn inv_txfm_add_16x16_8bpc_neon(
    _token: Arm64,
    dst: &mut [u8],
    dst_base: usize,
    dst_stride: isize,
    coeff: &mut [i16],
    eob: i32,
    _bitdepth_max: i32,
    row_tx: TxType16,
    col_tx: TxType16,
) {
    const EOB_HALF: i32 = 36;

    if eob == 0 && matches!((row_tx, col_tx), (TxType16::Dct, TxType16::Dct)) {
        dc_only_16x16_8bpc(dst, dst_base, dst_stride, coeff);
        return;
    }

    let mut scratch = [0i16; 256];
    // The lower eight rows are only processed when enough coefficients are present.
    let row_halves: usize = if eob < EOB_HALF { 1 } else { 2 };

    for half in 0..row_halves {
        let first_row = half * 8;

        // Load eight rows from each of the sixteen columns, then clear them.
        let mut v: V16 = [vdupq_n_s16(0); 16];
        for (col, lane) in v.iter_mut().enumerate() {
            let start = col * 16 + first_row;
            let mut chunk = [0i16; 8];
            chunk.copy_from_slice(&coeff[start..start + 8]);
            *lane = safe_simd::vld1q_s16(&chunk);
        }
        for col in 0..16 {
            let start = col * 16 + first_row;
            for c in &mut coeff[start..start + 8] {
                *c = 0;
            }
        }

        match row_tx {
            TxType16::Identity => {
                // Identity rows fold the scaling and the intermediate rounding together.
                let frac_mul = vdupq_n_s16(((5793 - 4096) * 2 * 8) as i16);
                for lane in &mut v {
                    let frac = vshrq_n_s16::<1>(vqrdmulhq_s16(*lane, frac_mul));
                    *lane = vrhaddq_s16(*lane, frac);
                }
            }
            TxType16::Dct | TxType16::Adst | TxType16::FlipAdst => {
                v = apply_tx16(row_tx, v);
                for lane in &mut v {
                    *lane = vrshrq_n_s16::<2>(*lane);
                }
            }
        }

        let mut upper_in = [v[0]; 8];
        upper_in.copy_from_slice(&v[..8]);
        let mut lower_in = [v[8]; 8];
        lower_in.copy_from_slice(&v[8..]);
        let (upper, lower) = transpose_16x16_half(upper_in, lower_in);

        for (i, (left, right)) in upper.iter().zip(lower.iter()).enumerate() {
            let start = (half * 8 + i) * 16;
            let mut chunk = [0i16; 8];
            safe_simd::vst1q_s16(&mut chunk, *left);
            scratch[start..start + 8].copy_from_slice(&chunk);
            safe_simd::vst1q_s16(&mut chunk, *right);
            scratch[start + 8..start + 16].copy_from_slice(&chunk);
        }
    }

    for half in 0..2usize {
        let first_col = half * 8;
        let mut v: V16 = [vdupq_n_s16(0); 16];
        for (lane, scratch_row) in v.iter_mut().zip(scratch.chunks_exact(16)) {
            let mut chunk = [0i16; 8];
            chunk.copy_from_slice(&scratch_row[first_col..first_col + 8]);
            *lane = safe_simd::vld1q_s16(&chunk);
        }
        let transformed = apply_tx16(col_tx, v);
        add_to_dst_8x16_8bpc(dst, dst_base + first_col, dst_stride, transformed);
    }
}
/// 16x16 inverse transform + add for 16-bit pixels.
///
/// * `dst` / `dst_base` / `dst_stride` — pixel buffer, index of the block's
///   first pixel, and distance between rows in elements (may be negative).
/// * `coeff` — 256 `i32` coefficients, read as `coeff[col * 16 + row]` and
///   saturated to `i16` on load; every coefficient that is read is zeroed.
/// * `eob` — with DCT/DCT and `eob == 0` only the DC shortcut runs; otherwise
///   rows 8..16 of the first pass are skipped when `eob < 36`.
/// * `bitdepth_max` — upper clamp for the output pixels.
/// * `row_tx` / `col_tx` — transforms for the first and second pass.
///
/// # Panics
/// Panics if `coeff` or `dst` is too short for the accesses described above.
// NOTE(review): coefficients are saturated to i16 before the row pass; confirm
// this precision is acceptable for every supported bit depth.
// NOTE(review): the second-half threshold is 36 for every transform pair;
// confirm it matches the scan order used for identity combinations.
#[cfg(target_arch = "aarch64")]
#[arcane]
pub(crate) fn inv_txfm_add_16x16_16bpc_neon(
    _token: Arm64,
    dst: &mut [u16],
    dst_base: usize,
    dst_stride: isize,
    coeff: &mut [i32],
    eob: i32,
    bitdepth_max: i32,
    row_tx: TxType16,
    col_tx: TxType16,
) {
    const EOB_HALF: i32 = 36;

    if eob == 0 && matches!((row_tx, col_tx), (TxType16::Dct, TxType16::Dct)) {
        dc_only_16x16_16bpc(dst, dst_base, dst_stride, coeff, bitdepth_max);
        return;
    }

    let mut scratch = [0i16; 256];
    // The lower eight rows are only processed when enough coefficients are present.
    let row_halves: usize = if eob < EOB_HALF { 1 } else { 2 };

    for half in 0..row_halves {
        let first_row = half * 8;

        // Load eight rows from each of the sixteen columns, narrowing with
        // saturation, then clear them.
        let mut v: V16 = [vdupq_n_s16(0); 16];
        for (col, lane) in v.iter_mut().enumerate() {
            let start = col * 16 + first_row;
            let mut chunk = [0i16; 8];
            for (narrow, &wide) in chunk.iter_mut().zip(&coeff[start..start + 8]) {
                *narrow = wide.clamp(i16::MIN as i32, i16::MAX as i32) as i16;
            }
            *lane = safe_simd::vld1q_s16(&chunk);
        }
        for col in 0..16 {
            let start = col * 16 + first_row;
            coeff[start..start + 8].fill(0);
        }

        match row_tx {
            TxType16::Identity => {
                // Identity rows fold the scaling and the intermediate rounding together.
                let frac_mul = vdupq_n_s16(((5793 - 4096) * 2 * 8) as i16);
                for lane in &mut v {
                    let frac = vshrq_n_s16::<1>(vqrdmulhq_s16(*lane, frac_mul));
                    *lane = vrhaddq_s16(*lane, frac);
                }
            }
            TxType16::Dct | TxType16::Adst | TxType16::FlipAdst => {
                v = apply_tx16(row_tx, v);
                for lane in &mut v {
                    *lane = vrshrq_n_s16::<2>(*lane);
                }
            }
        }

        let mut upper_in = [v[0]; 8];
        upper_in.copy_from_slice(&v[..8]);
        let mut lower_in = [v[8]; 8];
        lower_in.copy_from_slice(&v[8..]);
        let (upper, lower) = transpose_16x16_half(upper_in, lower_in);

        for (i, (left, right)) in upper.iter().zip(lower.iter()).enumerate() {
            let start = (half * 8 + i) * 16;
            let mut chunk = [0i16; 8];
            safe_simd::vst1q_s16(&mut chunk, *left);
            scratch[start..start + 8].copy_from_slice(&chunk);
            safe_simd::vst1q_s16(&mut chunk, *right);
            scratch[start + 8..start + 16].copy_from_slice(&chunk);
        }
    }

    for half in 0..2usize {
        let first_col = half * 8;
        let mut v: V16 = [vdupq_n_s16(0); 16];
        for (lane, scratch_row) in v.iter_mut().zip(scratch.chunks_exact(16)) {
            let mut chunk = [0i16; 8];
            chunk.copy_from_slice(&scratch_row[first_col..first_col + 8]);
            *lane = safe_simd::vld1q_s16(&chunk);
        }
        let transformed = apply_tx16(col_tx, v);
        add_to_dst_8x16_16bpc(dst, dst_base + first_col, dst_stride, transformed, bitdepth_max);
    }
}
// Generates an aarch64-only `#[arcane]` entry point named `$name` for 8-bit
// pixels. The generated function forwards all of its arguments unchanged to
// `inv_txfm_add_16x16_8bpc_neon`, with `$row` as the row transform and `$col`
// as the column transform.
macro_rules! def_16x16_8bpc {
($name:ident, $row:expr, $col:expr) => {
#[cfg(target_arch = "aarch64")]
#[arcane]
pub(crate) fn $name(
token: Arm64,
dst: &mut [u8],
dst_base: usize,
dst_stride: isize,
coeff: &mut [i16],
eob: i32,
bitdepth_max: i32,
) {
inv_txfm_add_16x16_8bpc_neon(
token,
dst,
dst_base,
dst_stride,
coeff,
eob,
bitdepth_max,
$row,
$col,
);
}
};
}
// Generates an aarch64-only `#[arcane]` entry point named `$name` for 16-bit
// pixels. The generated function forwards all of its arguments unchanged to
// `inv_txfm_add_16x16_16bpc_neon`, with `$row` as the row transform and `$col`
// as the column transform.
macro_rules! def_16x16_16bpc {
($name:ident, $row:expr, $col:expr) => {
#[cfg(target_arch = "aarch64")]
#[arcane]
pub(crate) fn $name(
token: Arm64,
dst: &mut [u16],
dst_base: usize,
dst_stride: isize,
coeff: &mut [i32],
eob: i32,
bitdepth_max: i32,
) {
inv_txfm_add_16x16_16bpc_neon(
token,
dst,
dst_base,
dst_stride,
coeff,
eob,
bitdepth_max,
$row,
$col,
);
}
};
}
// Concrete entry points: one 8 bpc and one 16 bpc wrapper per (row, column)
// transform pair. In each invocation the first `TxType16` is the row transform
// and the second is the column transform.

// DCT / DCT.
def_16x16_8bpc!(
inv_txfm_add_dct_dct_16x16_8bpc_neon_inner,
TxType16::Dct,
TxType16::Dct
);
def_16x16_16bpc!(
inv_txfm_add_dct_dct_16x16_16bpc_neon_inner,
TxType16::Dct,
TxType16::Dct
);
// Identity / identity.
def_16x16_8bpc!(
inv_txfm_add_identity_identity_16x16_8bpc_neon_inner,
TxType16::Identity,
TxType16::Identity
);
def_16x16_16bpc!(
inv_txfm_add_identity_identity_16x16_16bpc_neon_inner,
TxType16::Identity,
TxType16::Identity
);
// Combinations of DCT, ADST and flipped ADST.
def_16x16_8bpc!(
inv_txfm_add_adst_adst_16x16_8bpc_neon_inner,
TxType16::Adst,
TxType16::Adst
);
def_16x16_16bpc!(
inv_txfm_add_adst_adst_16x16_16bpc_neon_inner,
TxType16::Adst,
TxType16::Adst
);
def_16x16_8bpc!(
inv_txfm_add_dct_adst_16x16_8bpc_neon_inner,
TxType16::Dct,
TxType16::Adst
);
def_16x16_16bpc!(
inv_txfm_add_dct_adst_16x16_16bpc_neon_inner,
TxType16::Dct,
TxType16::Adst
);
def_16x16_8bpc!(
inv_txfm_add_adst_dct_16x16_8bpc_neon_inner,
TxType16::Adst,
TxType16::Dct
);
def_16x16_16bpc!(
inv_txfm_add_adst_dct_16x16_16bpc_neon_inner,
TxType16::Adst,
TxType16::Dct
);
def_16x16_8bpc!(
inv_txfm_add_dct_flipadst_16x16_8bpc_neon_inner,
TxType16::Dct,
TxType16::FlipAdst
);
def_16x16_16bpc!(
inv_txfm_add_dct_flipadst_16x16_16bpc_neon_inner,
TxType16::Dct,
TxType16::FlipAdst
);
def_16x16_8bpc!(
inv_txfm_add_flipadst_dct_16x16_8bpc_neon_inner,
TxType16::FlipAdst,
TxType16::Dct
);
def_16x16_16bpc!(
inv_txfm_add_flipadst_dct_16x16_16bpc_neon_inner,
TxType16::FlipAdst,
TxType16::Dct
);
def_16x16_8bpc!(
inv_txfm_add_flipadst_flipadst_16x16_8bpc_neon_inner,
TxType16::FlipAdst,
TxType16::FlipAdst
);
def_16x16_16bpc!(
inv_txfm_add_flipadst_flipadst_16x16_16bpc_neon_inner,
TxType16::FlipAdst,
TxType16::FlipAdst
);
def_16x16_8bpc!(
inv_txfm_add_adst_flipadst_16x16_8bpc_neon_inner,
TxType16::Adst,
TxType16::FlipAdst
);
def_16x16_16bpc!(
inv_txfm_add_adst_flipadst_16x16_16bpc_neon_inner,
TxType16::Adst,
TxType16::FlipAdst
);
def_16x16_8bpc!(
inv_txfm_add_flipadst_adst_16x16_8bpc_neon_inner,
TxType16::FlipAdst,
TxType16::Adst
);
def_16x16_16bpc!(
inv_txfm_add_flipadst_adst_16x16_16bpc_neon_inner,
TxType16::FlipAdst,
TxType16::Adst
);
// Combinations where exactly one pass is the identity transform.
def_16x16_8bpc!(
inv_txfm_add_dct_identity_16x16_8bpc_neon_inner,
TxType16::Dct,
TxType16::Identity
);
def_16x16_16bpc!(
inv_txfm_add_dct_identity_16x16_16bpc_neon_inner,
TxType16::Dct,
TxType16::Identity
);
def_16x16_8bpc!(
inv_txfm_add_identity_dct_16x16_8bpc_neon_inner,
TxType16::Identity,
TxType16::Dct
);
def_16x16_16bpc!(
inv_txfm_add_identity_dct_16x16_16bpc_neon_inner,
TxType16::Identity,
TxType16::Dct
);
def_16x16_8bpc!(
inv_txfm_add_adst_identity_16x16_8bpc_neon_inner,
TxType16::Adst,
TxType16::Identity
);
def_16x16_16bpc!(
inv_txfm_add_adst_identity_16x16_16bpc_neon_inner,
TxType16::Adst,
TxType16::Identity
);
def_16x16_8bpc!(
inv_txfm_add_identity_adst_16x16_8bpc_neon_inner,
TxType16::Identity,
TxType16::Adst
);
def_16x16_16bpc!(
inv_txfm_add_identity_adst_16x16_16bpc_neon_inner,
TxType16::Identity,
TxType16::Adst
);
def_16x16_8bpc!(
inv_txfm_add_flipadst_identity_16x16_8bpc_neon_inner,
TxType16::FlipAdst,
TxType16::Identity
);
def_16x16_16bpc!(
inv_txfm_add_flipadst_identity_16x16_16bpc_neon_inner,
TxType16::FlipAdst,
TxType16::Identity
);
def_16x16_8bpc!(
inv_txfm_add_identity_flipadst_16x16_8bpc_neon_inner,
TxType16::Identity,
TxType16::FlipAdst
);
def_16x16_16bpc!(
inv_txfm_add_identity_flipadst_16x16_16bpc_neon_inner,
TxType16::Identity,
TxType16::FlipAdst
);
#[cfg(test)]
#[cfg(target_arch = "aarch64")]
mod tests {
use super::*;
use crate::include::common::intops::iclip;
use archmage::SimdToken;
// Largest per-pixel absolute difference tolerated between the NEON output and
// the scalar reference in the comparison tests below.
const MAX_DIFF: i32 = 40;
/// Scalar reference 4-point inverse DCT using 12-bit fixed-point multipliers.
///
/// Every product sum is rounded with `(acc + 2048) >> 12`.
fn scalar_dct4_1d(input: &[i32; 4]) -> [i32; 4] {
    let round12 = |acc: i32| (acc + 2048) >> 12;
    let [x0, x1, x2, x3] = *input;

    let odd_hi = round12(x1 * 3784 + x3 * 1567);
    let odd_lo = round12(x1 * 1567 - x3 * 3784);
    let even_sum = round12(x0 * 2896 + x2 * 2896);
    let even_diff = round12(x0 * 2896 - x2 * 2896);

    [
        even_sum + odd_hi,
        even_diff + odd_lo,
        even_diff - odd_lo,
        even_sum - odd_hi,
    ]
}
/// Scalar reference 8-point inverse DCT.
///
/// Even-indexed inputs go through `scalar_dct4_1d`; odd-indexed inputs go
/// through two rotation stages, each rounded with `(acc + 2048) >> 12`.
fn scalar_dct8_1d(input: &[i32; 8]) -> [i32; 8] {
    let round12 = |acc: i32| (acc + 2048) >> 12;
    let [e0, e1, e2, e3] = scalar_dct4_1d(&[input[0], input[2], input[4], input[6]]);

    let a4 = round12(input[1] * 799 - input[7] * 4017);
    let a7 = round12(input[1] * 4017 + input[7] * 799);
    let a5 = round12(input[5] * 3406 - input[3] * 2276);
    let a6 = round12(input[5] * 2276 + input[3] * 3406);

    let (t4, d45) = (a4 + a5, a4 - a5);
    let (t7, d76) = (a7 + a6, a7 - a6);
    let t5 = round12(d76 * 2896 - d45 * 2896);
    let t6 = round12(d76 * 2896 + d45 * 2896);

    [
        e0 + t7,
        e1 + t6,
        e2 + t5,
        e3 + t4,
        e3 - t4,
        e2 - t5,
        e1 - t6,
        e0 - t7,
    ]
}
/// Scalar reference 16-point inverse DCT.
///
/// Even-indexed inputs go through `scalar_dct8_1d`; odd-indexed inputs go
/// through three rotation stages (each rounded with `(acc + 2048) >> 12`)
/// interleaved with butterflies. Output `k` is `even[k] + odd[k]` and output
/// `15 - k` is `even[k] - odd[k]`.
fn scalar_dct16_1d(input: &[i32; 16]) -> [i32; 16] {
    let rot_sub = |a: i32, ca: i32, b: i32, cb: i32| ((a * ca - b * cb) + 2048) >> 12;
    let rot_add = |a: i32, ca: i32, b: i32, cb: i32| ((a * ca + b * cb) + 2048) >> 12;

    let even = scalar_dct8_1d(&[
        input[0], input[2], input[4], input[6], input[8], input[10], input[12], input[14],
    ]);

    // Stage 1: rotations of the odd inputs.
    let r8 = rot_sub(input[1], 401, input[15], 4076);
    let r15 = rot_add(input[1], 4076, input[15], 401);
    let r9 = rot_sub(input[9], 3166, input[7], 2598);
    let r14 = rot_add(input[9], 2598, input[7], 3166);
    let r10 = rot_sub(input[5], 1931, input[11], 3612);
    let r13 = rot_add(input[5], 3612, input[11], 1931);
    let r11 = rot_sub(input[13], 3920, input[3], 1189);
    let r12 = rot_add(input[13], 1189, input[3], 3920);

    // Butterflies between neighbouring stage-1 results.
    let (b8, b9) = (r8 + r9, r8 - r9);
    let (b15, b14) = (r15 + r14, r15 - r14);
    let (b11, b10) = (r11 + r10, r11 - r10);
    let (b12, b13) = (r12 + r13, r12 - r13);

    // Stage 2: rotations; the last one is negated after rounding.
    let s9 = rot_sub(b14, 1567, b9, 3784);
    let s14 = rot_add(b14, 3784, b9, 1567);
    let s13 = rot_sub(b13, 1567, b10, 3784);
    let s10 = -rot_add(b13, 3784, b10, 1567);

    // Second round of butterflies.
    let (t8, d11) = (b8 + b11, b8 - b11);
    let (t15, d12) = (b15 + b12, b15 - b12);
    let (t9, d10) = (s9 + s10, s9 - s10);
    let (t14, d13) = (s14 + s13, s14 - s13);

    // Stage 3: rotations by 2896 on the difference terms.
    let t11 = rot_sub(d12, 2896, d11, 2896);
    let t12 = rot_add(d12, 2896, d11, 2896);
    let t10 = rot_sub(d13, 2896, d10, 2896);
    let t13 = rot_add(d13, 2896, d10, 2896);

    let odd = [t15, t14, t13, t12, t11, t10, t9, t8];
    let mut out = [0i32; 16];
    for k in 0..8 {
        out[k] = even[k] + odd[k];
        out[15 - k] = even[k] - odd[k];
    }
    out
}
/// Scalar reference for the DCT/DCT 16x16 inverse transform + add (8-bit).
///
/// Row pass: line `y` is gathered from `coeff[y + x * 16]`, transformed and
/// rounded with `(v + 2) >> 2`. Column pass: each column of the intermediate
/// is transformed, rounded with `(v + 8) >> 4`, added to `dst` and clipped to
/// `0..=255`. The first 256 coefficients are zeroed at the end.
fn scalar_dct_dct_16x16(dst: &mut [u8], stride: isize, coeff: &mut [i16]) {
    let mut mid = [0i32; 256];
    for (y, mid_row) in mid.chunks_exact_mut(16).enumerate() {
        let line: [i32; 16] = core::array::from_fn(|x| coeff[y + x * 16] as i32);
        let transformed = scalar_dct16_1d(&line);
        for (slot, &value) in mid_row.iter_mut().zip(transformed.iter()) {
            *slot = (value + 2) >> 2;
        }
    }

    for x in 0..16 {
        let line: [i32; 16] = core::array::from_fn(|y| mid[y * 16 + x]);
        let transformed = scalar_dct16_1d(&line);
        for (y, &value) in transformed.iter().enumerate() {
            let idx = (y as isize * stride) as usize + x;
            let pixel = dst[idx] as i32;
            let residual = (value + 8) >> 4;
            dst[idx] = iclip(pixel + residual, 0, 255) as u8;
        }
    }

    coeff[0..256].fill(0);
}
/// Scalar reference for the identity/identity 16x16 inverse transform + add (8-bit).
///
/// Per coefficient `c`, read from `coeff[y + x * 16]`:
/// * row pass: `t = round_q15(c * SCALE) >> 1`, `row = (c + t + 1) >> 1`;
/// * column pass: `col = 2 * row + round_q15(row * SCALE)` (saturating in `i32`);
/// * output: `dst += (col + 8) >> 4`, clipped to `0..=255`.
///
/// The first 256 coefficients are zeroed at the end.
fn scalar_identity_identity_16x16(dst: &mut [u8], stride: isize, coeff: &mut [i16]) {
    // Same scale expression the NEON identity code uses; evaluates to 27152.
    const SCALE_Q15: i64 = (5793 - 4096) * 2 * 8;

    for y in 0..16 {
        let row_off = (y as isize * stride) as usize;
        for x in 0..16 {
            let c = coeff[y + x * 16] as i32;
            let t = ((c as i64 * SCALE_Q15 + 16384) >> 15) as i32;
            let t = t >> 1;
            let row_result = (c + t + 1) >> 1;
            let t2 = ((row_result as i64 * SCALE_Q15 + 16384) >> 15) as i32;
            let col_result = row_result.saturating_mul(2).saturating_add(t2);
            let final_val = (col_result + 8) >> 4;
            let d = dst[row_off + x] as i32;
            dst[row_off + x] = iclip(d + final_val, 0, 255) as u8;
        }
    }
    coeff[0..256].fill(0);
}
/// Compares the NEON DCT/DCT 16x16 path against the scalar reference for three
/// coefficient patterns (DC only, first 16 coefficients, all 256), allowing a
/// per-pixel difference of at most `MAX_DIFF`.
#[test]
fn test_dct_dct_16x16_neon_vs_scalar() {
    let token = archmage::Arm64::summon().expect("NEON must be available");
    for pattern in 0..3 {
        let mut coeff_neon = [0i16; 256];
        match pattern {
            0 => coeff_neon[0] = 1000,
            1 => {
                for (i, c) in coeff_neon.iter_mut().take(16).enumerate() {
                    *c = ((i as i16 + 1) * 100) % 2000 - 1000;
                }
            }
            2 => {
                for (i, c) in coeff_neon.iter_mut().enumerate() {
                    *c = ((i * 37 + 13) % 2001) as i16 - 1000;
                }
            }
            _ => unreachable!(),
        }
        // Both implementations start from identical coefficients.
        let mut coeff_scalar = coeff_neon;

        let stride = 16isize;
        let mut dst_neon = [128u8; 16 * 16];
        let mut dst_scalar = [128u8; 16 * 16];
        inv_txfm_add_16x16_8bpc_neon(
            token,
            &mut dst_neon,
            0,
            stride,
            &mut coeff_neon,
            255,
            255,
            TxType16::Dct,
            TxType16::Dct,
        );
        scalar_dct_dct_16x16(&mut dst_scalar, stride, &mut coeff_scalar);

        let max_diff_seen = dst_neon
            .iter()
            .zip(dst_scalar.iter())
            .map(|(&a, &b)| (a as i32 - b as i32).abs())
            .fold(0i32, i32::max);
        assert!(
            max_diff_seen <= MAX_DIFF,
            "DCT_DCT 16x16 pattern {pattern}: max diff = {max_diff_seen} (expected <= {MAX_DIFF})"
        );
    }
}
/// Compares the NEON identity/identity 16x16 path against the scalar reference
/// for three coefficient patterns (DC only, first 32 coefficients, all 256),
/// allowing a per-pixel difference of at most `MAX_DIFF`.
#[test]
fn test_identity_identity_16x16_neon_vs_scalar() {
    let token = archmage::Arm64::summon().expect("NEON must be available");
    for pattern in 0..3 {
        let mut coeff_neon = [0i16; 256];
        match pattern {
            0 => coeff_neon[0] = 500,
            1 => {
                for (i, c) in coeff_neon.iter_mut().take(32).enumerate() {
                    *c = ((i as i16 + 1) * 50) % 1000 - 500;
                }
            }
            2 => {
                for (i, c) in coeff_neon.iter_mut().enumerate() {
                    *c = ((i * 31 + 7) % 1001) as i16 - 500;
                }
            }
            _ => unreachable!(),
        }
        // Both implementations start from identical coefficients.
        let mut coeff_scalar = coeff_neon;

        let stride = 16isize;
        let mut dst_neon = [128u8; 16 * 16];
        let mut dst_scalar = [128u8; 16 * 16];
        inv_txfm_add_16x16_8bpc_neon(
            token,
            &mut dst_neon,
            0,
            stride,
            &mut coeff_neon,
            255,
            255,
            TxType16::Identity,
            TxType16::Identity,
        );
        scalar_identity_identity_16x16(&mut dst_scalar, stride, &mut coeff_scalar);

        let max_diff_seen = dst_neon
            .iter()
            .zip(dst_scalar.iter())
            .map(|(&a, &b)| (a as i32 - b as i32).abs())
            .fold(0i32, i32::max);
        assert!(
            max_diff_seen <= MAX_DIFF,
            "IDENTITY_IDENTITY 16x16 pattern {pattern}: max diff = {max_diff_seen} (expected <= {MAX_DIFF})"
        );
    }
}
/// With `eob == 0` and DCT/DCT the DC shortcut must clear `coeff[0]` and add
/// the same value to every pixel of a uniform block.
#[test]
fn test_dc_only_16x16() {
    let token = archmage::Arm64::summon().expect("NEON must be available");
    let stride = 16isize;
    let mut dst = [128u8; 16 * 16];
    let mut coeff = [0i16; 256];
    coeff[0] = 1000;

    inv_txfm_add_16x16_8bpc_neon(
        token,
        &mut dst,
        0,
        stride,
        &mut coeff,
        0,
        255,
        TxType16::Dct,
        TxType16::Dct,
    );

    assert_eq!(coeff[0], 0);
    let first = dst[0];
    for (i, &pixel) in dst.iter().enumerate().skip(1) {
        assert_eq!(pixel, first, "DC output not uniform at pixel {i}");
    }
}
}