use super::constants::*;
use crate::vardct::common::{as_array_mut, as_array_ref};
/// In-place 2-point inverse DCT butterfly.
///
/// Overwrites `mem[0]` with half the sum and `mem[1]` with half the
/// difference of the first two elements. Elements past index 1 are left
/// untouched.
///
/// # Panics
///
/// Panics if `mem` holds fewer than two elements.
#[inline]
pub fn idct1d_2(mem: &mut [f32]) {
    let (lhs, rhs) = (mem[0], mem[1]);
    let half_sum = (lhs + rhs) * 0.5;
    let half_diff = (lhs - rhs) * 0.5;
    mem[0] = half_sum;
    mem[1] = half_diff;
}
/// Value-based 4-point inverse DCT kernel.
///
/// `a` and `b` form the even half (a plain half-sum / half-difference
/// butterfly). `c` and `d` form the odd half: `c - d` is scaled by
/// `INV_SQRT2`, run through the same butterfly against `d`, and weighted by
/// `INV_WC_MULTIPLIERS_4`.
///
/// Returns the four outputs in order. Outputs 0 and 3 are the half-sum and
/// half-difference of the first even/odd pair, outputs 1 and 2 of the second.
#[inline(always)]
fn idct1d_4_val(a: f32, b: f32, c: f32, d: f32) -> [f32; 4] {
    let halve = |v: f32| v * 0.5;

    // Even half: 2-point butterfly on (a, b).
    let even_sum = halve(a + b);
    let even_diff = halve(a - b);

    // Odd half: rotate, butterfly against `d`, then apply the per-output weights.
    let rotated = (c - d) * INV_SQRT2;
    let odd_sum = halve(rotated + d) * INV_WC_MULTIPLIERS_4[0];
    let odd_diff = halve(rotated - d) * INV_WC_MULTIPLIERS_4[1];

    [
        halve(even_sum + odd_sum),
        halve(even_diff + odd_diff),
        halve(even_diff - odd_diff),
        halve(even_sum - odd_sum),
    ]
}
/// In-place 4-point inverse DCT over the first four elements of `mem`.
///
/// The even-indexed coefficients (`mem[0]`, `mem[2]`) and odd-indexed
/// coefficients (`mem[1]`, `mem[3]`) are handed to `idct1d_4_val`, and the
/// four results are written back to `mem[0..4]`. No scaling is applied here.
///
/// # Panics
///
/// Panics if `mem` holds fewer than four elements.
#[inline(always)]
pub fn idct1d_4(mem: &mut [f32]) {
    let (even0, even1) = (mem[0], mem[2]);
    let (odd0, odd1) = (mem[1], mem[3]);
    let [s0, s1, s2, s3] = idct1d_4_val(even0, even1, odd0, odd1);
    mem[0] = s0;
    mem[1] = s1;
    mem[2] = s2;
    mem[3] = s3;
}
/// Value-based 8-point inverse DCT kernel (no input scaling).
///
/// The even-indexed inputs go straight through a 4-point kernel. The
/// odd-indexed inputs are first turned into running differences (from the
/// highest index down), with the lowest term scaled by `INV_SQRT2`. They then
/// go through the same 4-point kernel and are weighted by
/// `INV_WC_MULTIPLIERS_8`. Output `i` and its mirror `7 - i` are the half-sum
/// and half-difference of the even and weighted odd results.
#[inline(always)]
fn idct1d_8_core_val(m: [f32; 8]) -> [f32; 8] {
    let [x0, x1, x2, x3, x4, x5, x6, x7] = m;

    // Odd half: cascade of differences, then the INV_SQRT2 rotation on the first term.
    let t3 = x7;
    let t2 = x5 - t3;
    let t1 = x3 - t2;
    let t0 = (x1 - t1) * INV_SQRT2;

    let odd_half = idct1d_4_val(t0, t2, t1, t3);
    let even_half = idct1d_4_val(x0, x4, x2, x6);

    let mut out = [0.0f32; 8];
    for i in 0..4 {
        let weighted = odd_half[i] * INV_WC_MULTIPLIERS_8[i];
        out[i] = (even_half[i] + weighted) * 0.5;
        out[7 - i] = (even_half[i] - weighted) * 0.5;
    }
    out
}
/// In-place, unscaled 8-point inverse DCT over `mem[0..8]`.
///
/// Copies the first eight values out, runs `idct1d_8_core_val` on them and
/// stores the eight results back. Elements past index 7 are untouched.
///
/// # Panics
///
/// Panics if `mem` holds fewer than eight elements.
fn idct1d_8_core(mem: &mut [f32]) {
    let head = &mut mem[..8];
    let mut coeffs = [0.0f32; 8];
    for (slot, &value) in coeffs.iter_mut().zip(head.iter()) {
        *slot = value;
    }
    let samples = idct1d_8_core_val(coeffs);
    for (slot, &value) in head.iter_mut().zip(samples.iter()) {
        *slot = value;
    }
}
/// In-place 8-point inverse DCT over `mem[0..8]`.
///
/// Each of the first eight coefficients is multiplied by 8 before being
/// passed to `idct1d_8_core_val`; the results overwrite `mem[0..8]`.
///
/// # Panics
///
/// Panics if `mem` holds fewer than eight elements.
pub fn idct1d_8(mem: &mut [f32]) {
    let mut scaled = [0.0f32; 8];
    for (idx, slot) in scaled.iter_mut().enumerate() {
        *slot = mem[idx] * 8.0;
    }
    let samples = idct1d_8_core_val(scaled);
    mem[..8].copy_from_slice(&samples);
}
/// In-place 16-point inverse DCT over `mem[0..16]`.
///
/// The coefficients are multiplied by 16 and split into even- and odd-indexed
/// halves. The odd half is turned into running differences (highest index
/// first), with its first term scaled by `INV_SQRT2`. Both halves go through
/// the 8-point kernel, the odd results are weighted by
/// `INV_WC_MULTIPLIERS_16`, and output `i` and its mirror `15 - i` receive the
/// half-sum and half-difference.
///
/// # Panics
///
/// Panics if `mem` holds fewer than sixteen elements.
pub fn idct1d_16(mem: &mut [f32]) {
    const SCALE: f32 = 16.0;

    let mut even_half = [0.0f32; 8];
    let mut odd_half = [0.0f32; 8];
    for (pair, (even_slot, odd_slot)) in even_half
        .iter_mut()
        .zip(odd_half.iter_mut())
        .enumerate()
    {
        *even_slot = mem[2 * pair] * SCALE;
        *odd_slot = mem[2 * pair + 1] * SCALE;
    }

    // Running differences from index 6 down to 1, then the rotation on index 0.
    for idx in (1..=6).rev() {
        odd_half[idx] -= odd_half[idx + 1];
    }
    odd_half[0] = (odd_half[0] - odd_half[1]) * INV_SQRT2;

    let odd_out = idct1d_8_core_val(odd_half);
    let even_out = idct1d_8_core_val(even_half);

    for idx in 0..8 {
        let weighted = odd_out[idx] * INV_WC_MULTIPLIERS_16[idx];
        mem[idx] = (even_out[idx] + weighted) * 0.5;
        mem[15 - idx] = (even_out[idx] - weighted) * 0.5;
    }
}
/// Straightforward 8-point reference inverse DCT, evaluated from the cosine sum.
///
/// For every `k` in `0..8`:
/// `output[k] = 0.5 * input[0] + sum_{j=1..8} input[j] * cos(pi * j * (2k + 1) / 16)`.
///
/// # Panics
///
/// Panics if `input` or `output` holds fewer than eight elements.
// Index loops are kept on purpose: both indices appear in the cosine argument.
#[allow(clippy::needless_range_loop)]
fn idct1d_8_ref(input: &[f32], output: &mut [f32]) {
    const N: usize = 8;
    let denominator = 2.0 * N as f32;
    for k in 0..N {
        let odd_multiple = (2 * k + 1) as f32;
        let mut acc = 0.5 * input[0];
        for j in 1..N {
            let phase = core::f32::consts::PI * (j as f32) * odd_multiple / denominator;
            acc += input[j] * phase.cos();
        }
        output[k] = acc;
    }
}
/// Thin wrapper that forwards `input` and `output` to `jxl_simd::idct_8x8`.
///
/// Both buffers hold exactly 64 `f32` values. The transform itself
/// (coefficient layout, scaling) is defined by the `jxl_simd` implementation.
#[inline]
pub fn idct_8x8(input: &[f32; 64], output: &mut [f32; 64]) {
jxl_simd::idct_8x8(input, output);
}
/// Thin wrapper that forwards `input` and `output` to `jxl_simd::idct_16x16`.
///
/// Both buffers hold exactly 256 `f32` values. The transform itself
/// (coefficient layout, scaling) is defined by the `jxl_simd` implementation.
#[inline]
pub fn idct_16x16(input: &[f32; 256], output: &mut [f32; 256]) {
jxl_simd::idct_16x16(input, output);
}
/// Thin wrapper that forwards `input` and `output` to `jxl_simd::idct_16x8`.
///
/// Both buffers hold exactly 128 `f32` values. Which dimension is rows and
/// which is columns is defined by the `jxl_simd` implementation.
#[inline]
pub fn idct_16x8(input: &[f32; 128], output: &mut [f32; 128]) {
jxl_simd::idct_16x8(input, output);
}
/// Thin wrapper that forwards `input` and `output` to `jxl_simd::idct_8x16`.
///
/// Both buffers hold exactly 128 `f32` values. Which dimension is rows and
/// which is columns is defined by the `jxl_simd` implementation.
#[inline]
pub fn idct_8x16(input: &[f32; 128], output: &mut [f32; 128]) {
jxl_simd::idct_8x16(input, output);
}
#[inline(always)]
pub fn idct_4x4(input: &[f32; 16], output: &mut [f32; 16]) {
let mut temp = [0.0f32; 16];
for row in 0..4 {
let s = row * 4;
let r = idct1d_4_val(
input[s] * 4.0,
input[s + 2] * 4.0,
input[s + 1] * 4.0,
input[s + 3] * 4.0,
);
for col in 0..4 {
temp[col * 4 + row] = r[col];
}
}
for row in 0..4 {
let s = row * 4;
let r = idct1d_4_val(
temp[s] * 4.0,
temp[s + 2] * 4.0,
temp[s + 1] * 4.0,
temp[s + 3] * 4.0,
);
output[s..s + 4].copy_from_slice(&r);
}
}
/// Two-pass inverse DCT on a 32-element block laid out as 4 rows of 8.
///
/// Pass one walks the eight columns: the four values of a column (stride 8)
/// are multiplied by 4, run through the 4-point kernel and written back to
/// the same column of a scratch buffer. Pass two multiplies each scratch row
/// of eight by 8, runs the 8-point kernel and stores it in the matching row
/// of `output`.
#[inline(always)]
pub fn idct_4x8(input: &[f32; 32], output: &mut [f32; 32]) {
    let mut column_pass = [0.0f32; 32];
    for col in 0..8 {
        // Scaled coefficient of this column taken from the given row.
        let tap = |row: usize| input[row * 8 + col] * 4.0;
        let column = idct1d_4_val(tap(0), tap(2), tap(1), tap(3));
        for (row, &sample) in column.iter().enumerate() {
            column_pass[row * 8 + col] = sample;
        }
    }

    for (src, dst) in column_pass.chunks_exact(8).zip(output.chunks_exact_mut(8)) {
        let mut scaled = [0.0f32; 8];
        for (slot, &value) in scaled.iter_mut().zip(src) {
            *slot = value * 8.0;
        }
        dst.copy_from_slice(&idct1d_8_core_val(scaled));
    }
}
/// Two-pass inverse DCT on a 32-element block, producing 8 rows of 4.
///
/// Pass one reads `input` as four rows of eight coefficients, multiplies each
/// by 8, runs the 8-point kernel and stores the result transposed (eight rows
/// of four). Pass two multiplies each of those rows by 4, runs the 4-point
/// kernel and writes it to the matching four-element row of `output`.
#[inline(always)]
pub fn idct_8x4(input: &[f32; 32], output: &mut [f32; 32]) {
    let mut transposed = [0.0f32; 32];
    for (row, coeffs) in input.chunks_exact(8).enumerate() {
        let mut scaled = [0.0f32; 8];
        for (slot, &value) in scaled.iter_mut().zip(coeffs) {
            *slot = value * 8.0;
        }
        let samples = idct1d_8_core_val(scaled);
        for (col, &sample) in samples.iter().enumerate() {
            transposed[col * 4 + row] = sample;
        }
    }

    for (coeffs, dst) in transposed.chunks_exact(4).zip(output.chunks_exact_mut(4)) {
        let samples = idct1d_4_val(
            coeffs[0] * 4.0,
            coeffs[2] * 4.0,
            coeffs[1] * 4.0,
            coeffs[3] * 4.0,
        );
        dst.copy_from_slice(&samples);
    }
}
/// In-place, unscaled 16-point inverse DCT over `mem[0..16]`.
///
/// Splits the input into even- and odd-indexed halves. The odd half is turned
/// into running differences (highest index first), with its first term scaled
/// by `INV_SQRT2`. Both halves go through the 8-point kernel, the odd results
/// are weighted by `INV_WC_MULTIPLIERS_16`, and output `i` and its mirror
/// `15 - i` receive the half-sum and half-difference.
///
/// # Panics
///
/// Panics if `mem` holds fewer than sixteen elements.
fn idct1d_16_core(mem: &mut [f32]) {
    let mut evens = [0.0f32; 8];
    let mut odds = [0.0f32; 8];
    for pair in 0..8 {
        evens[pair] = mem[2 * pair];
        odds[pair] = mem[2 * pair + 1];
    }

    // Running differences from index 6 down to 1, then the rotation on index 0.
    for idx in (1..=6).rev() {
        odds[idx] -= odds[idx + 1];
    }
    odds[0] = (odds[0] - odds[1]) * INV_SQRT2;

    let odd_out = idct1d_8_core_val(odds);
    let even_out = idct1d_8_core_val(evens);

    for idx in 0..8 {
        let weighted = odd_out[idx] * INV_WC_MULTIPLIERS_16[idx];
        mem[idx] = (even_out[idx] + weighted) * 0.5;
        mem[15 - idx] = (even_out[idx] - weighted) * 0.5;
    }
}
/// In-place 32-point inverse DCT: multiplies up to the first 32 elements of
/// `mem` by 32 and then runs `idct1d_32_core` on `mem`.
fn idct1d_32(mem: &mut [f32]) {
    mem.iter_mut().take(32).for_each(|value| *value *= 32.0);
    idct1d_32_core(mem);
}
/// In-place, unscaled 32-point inverse DCT over `mem[0..32]`.
///
/// Splits the input into even- and odd-indexed halves. The odd half is turned
/// into running differences (highest index first), with its first term scaled
/// by `INV_SQRT2`. Both halves go through `idct1d_16_core`, the odd results
/// are weighted by `INV_WC_MULTIPLIERS_32`, and output `i` and its mirror
/// `31 - i` receive the half-sum and half-difference.
///
/// # Panics
///
/// Panics if `mem` holds fewer than 32 elements.
fn idct1d_32_core(mem: &mut [f32]) {
    // One scratch buffer: first 16 slots hold the even half, last 16 the odd half.
    let mut scratch = [0.0f32; 32];
    let (even, odd) = scratch.split_at_mut(16);
    for pair in 0..16 {
        even[pair] = mem[2 * pair];
        odd[pair] = mem[2 * pair + 1];
    }

    for idx in (1..=14).rev() {
        odd[idx] -= odd[idx + 1];
    }
    odd[0] = (odd[0] - odd[1]) * INV_SQRT2;

    idct1d_16_core(odd);
    idct1d_16_core(even);

    for idx in 0..16 {
        let weighted = odd[idx] * INV_WC_MULTIPLIERS_32[idx];
        mem[idx] = (even[idx] + weighted) * 0.5;
        mem[31 - idx] = (even[idx] - weighted) * 0.5;
    }
}
/// Thin wrapper that forwards `input` and `output` to `jxl_simd::idct_32x32`.
///
/// Both buffers hold exactly 1024 `f32` values. The transform itself
/// (coefficient layout, scaling) is defined by the `jxl_simd` implementation.
pub fn idct_32x32(input: &[f32; 1024], output: &mut [f32; 1024]) {
jxl_simd::idct_32x32(input, output);
}
/// Thin wrapper that forwards `input` and `output` to `jxl_simd::idct_32x16`.
///
/// Both buffers hold exactly 512 `f32` values. Which dimension is rows and
/// which is columns is defined by the `jxl_simd` implementation.
pub fn idct_32x16(input: &[f32; 512], output: &mut [f32; 512]) {
jxl_simd::idct_32x16(input, output);
}
/// Thin wrapper that forwards `input` and `output` to `jxl_simd::idct_16x32`.
///
/// Both buffers hold exactly 512 `f32` values. Which dimension is rows and
/// which is columns is defined by the `jxl_simd` implementation.
pub fn idct_16x32(input: &[f32; 512], output: &mut [f32; 512]) {
jxl_simd::idct_16x32(input, output);
}
/// In-place 64-point inverse DCT: multiplies up to the first 64 elements of
/// `mem` by 64 and then runs `idct1d_64_core` on `mem`.
fn idct1d_64(mem: &mut [f32]) {
    let scaled_len = mem.len().min(64);
    for value in &mut mem[..scaled_len] {
        *value *= 64.0;
    }
    idct1d_64_core(mem);
}
/// In-place, unscaled 64-point inverse DCT over `mem[0..64]`.
///
/// Splits the input into even- and odd-indexed halves. The odd half is turned
/// into running differences (highest index first), with its first term scaled
/// by `INV_SQRT2`. Both halves go through `idct1d_32_core`, the odd results
/// are weighted by `INV_WC_MULTIPLIERS_64`, and output `i` and its mirror
/// `63 - i` receive the half-sum and half-difference.
///
/// # Panics
///
/// Panics if `mem` holds fewer than 64 elements.
fn idct1d_64_core(mem: &mut [f32]) {
    let mut even = [0.0f32; 32];
    let mut odd = [0.0f32; 32];
    for (pair, (even_slot, odd_slot)) in even.iter_mut().zip(odd.iter_mut()).enumerate() {
        *even_slot = mem[2 * pair];
        *odd_slot = mem[2 * pair + 1];
    }

    // Running differences from index 30 down to 1, then the rotation on index 0.
    for idx in (1..=30).rev() {
        odd[idx] -= odd[idx + 1];
    }
    odd[0] = (odd[0] - odd[1]) * INV_SQRT2;

    idct1d_32_core(&mut odd);
    for idx in 0..32 {
        odd[idx] *= INV_WC_MULTIPLIERS_64[idx];
    }
    idct1d_32_core(&mut even);

    for idx in 0..32 {
        mem[idx] = (even[idx] + odd[idx]) * 0.5;
        mem[63 - idx] = (even[idx] - odd[idx]) * 0.5;
    }
}
/// Forwards slice-based buffers to `jxl_simd::idct_64x64`.
///
/// The slices are converted with `as_array_ref` / `as_array_mut`, called with
/// `0` as second argument (presumably a start offset — confirm against
/// `vardct::common`).
///
/// Debug builds assert that both slices hold at least 4096 values.
/// NOTE(review): behaviour on shorter slices in release builds depends on
/// `as_array_ref` / `as_array_mut` — confirm.
pub fn idct_64x64(input: &[f32], output: &mut [f32]) {
debug_assert!(input.len() >= 4096);
debug_assert!(output.len() >= 4096);
jxl_simd::idct_64x64(as_array_ref(input, 0), as_array_mut(output, 0));
}
/// Forwards slice-based buffers to `jxl_simd::idct_64x32`.
///
/// The slices are converted with `as_array_ref` / `as_array_mut`, called with
/// `0` as second argument (presumably a start offset — confirm against
/// `vardct::common`).
///
/// Debug builds assert that both slices hold at least 2048 values.
/// NOTE(review): behaviour on shorter slices in release builds depends on
/// `as_array_ref` / `as_array_mut` — confirm.
pub fn idct_64x32(input: &[f32], output: &mut [f32]) {
debug_assert!(input.len() >= 2048);
debug_assert!(output.len() >= 2048);
jxl_simd::idct_64x32(as_array_ref(input, 0), as_array_mut(output, 0));
}
/// Forwards slice-based buffers to `jxl_simd::idct_32x64`.
///
/// The slices are converted with `as_array_ref` / `as_array_mut`, called with
/// `0` as second argument (presumably a start offset — confirm against
/// `vardct::common`).
///
/// Debug builds assert that both slices hold at least 2048 values.
/// NOTE(review): behaviour on shorter slices in release builds depends on
/// `as_array_ref` / `as_array_mut` — confirm.
pub fn idct_32x64(input: &[f32], output: &mut [f32]) {
debug_assert!(input.len() >= 2048);
debug_assert!(output.len() >= 2048);
jxl_simd::idct_32x64(as_array_ref(input, 0), as_array_mut(output, 0));
}
/// Straightforward `n`-point reference inverse DCT, evaluated from the cosine sum.
///
/// For every `k` in `0..n`:
/// `output[k] = 0.5 * input[0] + sum_{j=1..n} input[j] * cos(pi * j * (2k + 1) / (2n))`.
/// With `n == 0` nothing is read or written.
///
/// # Panics
///
/// Panics if `n > 0` and `input` or `output` holds fewer than `n` elements.
// The output index also feeds the cosine argument, so the index loop mirrors the formula.
#[allow(clippy::needless_range_loop)]
fn idct1d_n_ref(input: &[f32], output: &mut [f32], n: usize) {
    let two_n = 2.0 * n as f32;
    for k in 0..n {
        let odd_multiple = (2 * k + 1) as f32;
        output[k] = (1..n).fold(0.5 * input[0], |acc, j| {
            let phase = core::f32::consts::PI * (j as f32) * odd_multiple / two_n;
            acc + input[j] * phase.cos()
        });
    }
}
/// Returns the first coefficient (`coeffs[0]`) of a 64-element block unchanged.
#[inline]
pub fn dc_from_dct_8x8(coeffs: &[f32; 64]) -> f32 {
    let [dc, ..] = *coeffs;
    dc
}
pub fn dc_from_dct_16x8(coeffs: &[f32; 128]) -> [f32; 2] {
let lf0 = coeffs[0] * DCT_RESAMPLE_SCALE_16_TO_2[0];
let lf1 = coeffs[1] * DCT_RESAMPLE_SCALE_16_TO_2[1];
[lf0 + lf1, lf0 - lf1]
}
/// Derives two values from the first two coefficients of a 128-element block.
///
/// `coeffs[0]` and `coeffs[1]` are each multiplied by the matching entry of
/// `DCT_RESAMPLE_SCALE_16_TO_2`; the result is `[sum, difference]` of the two
/// scaled terms.
pub fn dc_from_dct_8x16(coeffs: &[f32; 128]) -> [f32; 2] {
    let [first, second, ..] = *coeffs;
    let base = first * DCT_RESAMPLE_SCALE_16_TO_2[0];
    let delta = second * DCT_RESAMPLE_SCALE_16_TO_2[1];
    let sum = base + delta;
    let difference = base - delta;
    [sum, difference]
}
/// Reference 4-point inverse DCT evaluated directly from its cosine sum.
///
/// For every `k` in `0..4`:
/// `output[k] = input[0] + 2 * sum_{j=1..=3} input[j] * cos(pi * j * (2k + 1) / 8)`.
///
/// The first coefficient enters at full weight and the remaining ones are
/// doubled.
fn idct1d_4_ref(input: &[f32; 4], output: &mut [f32; 4]) {
    let [x0, x1, x2, x3] = *input;
    for (k, out) in output.iter_mut().enumerate() {
        let odd_multiple = 2 * k + 1;
        // The integer multiple is formed first so the angle is `(PI * m) / 8`,
        // the same association as writing each term out by hand.
        let basis = |j: usize| {
            let multiple = (odd_multiple * j) as f32;
            (core::f32::consts::PI * multiple / 8.0).cos()
        };
        *out = x0 + 2.0 * (x1 * basis(1) + x2 * basis(2) + x3 * basis(3));
    }
}
/// Thin wrapper that forwards `input` and `output` to `jxl_simd::idct_4x8_full`.
///
/// Both buffers hold exactly 64 `f32` values. How the 4x8 transforms are
/// arranged inside them is defined by the `jxl_simd` implementation.
#[inline(always)]
pub fn idct_4x8_full(input: &[f32; 64], output: &mut [f32; 64]) {
jxl_simd::idct_4x8_full(input, output);
}
/// Thin wrapper that forwards `input` and `output` to `jxl_simd::idct_8x4_full`.
///
/// Both buffers hold exactly 64 `f32` values. How the 8x4 transforms are
/// arranged inside them is defined by the `jxl_simd` implementation.
#[inline(always)]
pub fn idct_8x4_full(input: &[f32; 64], output: &mut [f32; 64]) {
jxl_simd::idct_8x4_full(input, output);
}
/// Thin wrapper that forwards `input` and `output` to `jxl_simd::idct_4x4_full`.
///
/// Both buffers hold exactly 64 `f32` values. How the 4x4 transforms are
/// arranged inside them is defined by the `jxl_simd` implementation.
#[inline(always)]
pub fn idct_4x4_full(input: &[f32; 64], output: &mut [f32; 64]) {
jxl_simd::idct_4x4_full(input, output);
}