#![allow(clippy::approx_constant)] #![allow(clippy::items_after_statements)] #![allow(clippy::unreadable_literal)]
use crate::dft::codelets::simd;
use crate::kernel::{Complex, Float};
#[inline]
#[allow(clippy::excessive_precision)]
pub(super) fn dft5<T: Float>(input: &[Complex<T>; 5], sign_t: T) -> [Complex<T>; 5] {
let c1 = T::from_f64(0.30901699437494742);
let c2 = T::from_f64(-0.80901699437494742);
let s1 = T::from_f64(0.95105651629515357);
let s2 = T::from_f64(0.58778525229247313);
let x0 = input[0];
let x1 = input[1];
let x2 = input[2];
let x3 = input[3];
let x4 = input[4];
let p1 = x1 + x4;
let p2 = x2 + x3;
let m1 = x1 - x4;
let m2 = x2 - x3;
let y0 = x0 + p1 + p2;
let t1 = x0 + p1.scale(c1) + p2.scale(c2);
let t2 = x0 + p1.scale(c2) + p2.scale(c1);
let r1_re = -sign_t * (m1.im * s1 + m2.im * s2);
let r1_im = sign_t * (m1.re * s1 + m2.re * s2);
let r2_re = -sign_t * (m1.im * s2 - m2.im * s1);
let r2_im = sign_t * (m1.re * s2 - m2.re * s1);
let y1 = Complex::new(t1.re + r1_re, t1.im + r1_im);
let y4 = Complex::new(t1.re - r1_re, t1.im - r1_im);
let y2 = Complex::new(t2.re + r2_re, t2.im + r2_im);
let y3 = Complex::new(t2.re - r2_re, t2.im - r2_im);
[y0, y1, y2, y3, y4]
}
#[inline]
pub(super) fn dft9<T: Float>(input: &[Complex<T>; 9], sign_t: T) -> [Complex<T>; 9] {
let sqrt3_2 = T::from_f64(0.8660254037844387);
let half = T::from_f64(0.5);
let mut t: [Complex<T>; 9] = [Complex::zero(); 9];
for j in 0..3 {
let a0 = input[j];
let a1 = input[j + 3];
let a2 = input[j + 6];
let sum = a0 + a1 + a2;
let d1 = a1 - a2;
let d2 = a0 - (a1 + a2).scale(half);
let rot_re = -sign_t * sqrt3_2 * d1.im;
let rot_im = -sign_t * sqrt3_2 * d1.re;
t[j] = sum;
t[j + 3] = Complex::new(d2.re + rot_re, d2.im - rot_im);
t[j + 6] = Complex::new(d2.re - rot_re, d2.im + rot_im);
}
let c1 = T::from_f64(0.766044443118978);
let s1 = T::from_f64(0.6427876096865393);
let c2 = T::from_f64(0.17364817766693041);
let s2 = T::from_f64(0.984807753012208);
let c4 = T::from_f64(-0.9396926207859084);
let s4 = T::from_f64(0.3420201433256687);
let tw = t[4];
t[4] = Complex::new(
tw.re * c1 - sign_t * tw.im * s1,
sign_t * tw.re * s1 + tw.im * c1,
);
let tw = t[5];
t[5] = Complex::new(
tw.re * c2 - sign_t * tw.im * s2,
sign_t * tw.re * s2 + tw.im * c2,
);
let tw = t[7];
t[7] = Complex::new(
tw.re * c2 - sign_t * tw.im * s2,
sign_t * tw.re * s2 + tw.im * c2,
);
let tw = t[8];
t[8] = Complex::new(
tw.re * c4 - sign_t * tw.im * s4,
sign_t * tw.re * s4 + tw.im * c4,
);
let mut y: [Complex<T>; 9] = [Complex::zero(); 9];
for k1 in 0..3 {
let base = k1 * 3;
let a0 = t[base];
let a1 = t[base + 1];
let a2 = t[base + 2];
let sum = a0 + a1 + a2;
let d1 = a1 - a2;
let d2 = a0 - (a1 + a2).scale(half);
let rot_re = -sign_t * sqrt3_2 * d1.im;
let rot_im = -sign_t * sqrt3_2 * d1.re;
y[k1] = sum;
y[k1 + 3] = Complex::new(d2.re + rot_re, d2.im - rot_im);
y[k1 + 6] = Complex::new(d2.re - rot_re, d2.im + rot_im);
}
y
}
#[inline]
pub(super) fn dft25<T: Float>(input: &[Complex<T>; 25], sign_t: T) -> [Complex<T>; 25] {
let mut t: [Complex<T>; 25] = [Complex::zero(); 25];
for j in 0..5 {
let col_input: [Complex<T>; 5] = [
input[j],
input[j + 5],
input[j + 10],
input[j + 15],
input[j + 20],
];
let col_output = dft5(&col_input, sign_t);
for k1 in 0..5 {
t[k1 * 5 + j] = col_output[k1];
}
}
let cos_25: [T; 17] = [
T::ONE,
T::from_f64(0.9685831611286311),
T::from_f64(0.8763066800438637),
T::from_f64(0.7289686274214116),
T::from_f64(0.5358267949789965),
T::from_f64(0.30901699437494745),
T::from_f64(0.06279051952931337),
T::from_f64(-0.18738131458572463),
T::from_f64(-0.4257792915650727),
T::from_f64(-0.6374239897486897),
T::from_f64(-0.8090169943749474),
T::from_f64(-0.9297764858882513),
T::from_f64(-0.9921147013144779),
T::from_f64(-0.9921147013144779),
T::from_f64(-0.9297764858882513),
T::from_f64(-0.8090169943749474),
T::from_f64(-0.6374239897486897),
];
let sin_25: [T; 17] = [
T::ZERO,
T::from_f64(0.2486898871648548),
T::from_f64(0.4817536741017153),
T::from_f64(0.6845471059286887),
T::from_f64(0.8443279255020151),
T::from_f64(0.9510565162951535),
T::from_f64(0.998026728428272),
T::from_f64(0.9822872507286887),
T::from_f64(0.9048270524660195),
T::from_f64(0.7705132427757893),
T::from_f64(0.5877852522924731),
T::from_f64(0.3681245526846779),
T::from_f64(0.12533323356430426),
T::from_f64(-0.12533323356430426),
T::from_f64(-0.3681245526846779),
T::from_f64(-0.5877852522924731),
T::from_f64(-0.7705132427757893),
];
for k1 in 1..5 {
for k2 in 1..5 {
let k = k1 * k2;
let c = cos_25[k];
let s = sin_25[k];
let idx = k1 * 5 + k2;
let tw = t[idx];
t[idx] = Complex::new(
tw.re * c - sign_t * tw.im * s,
sign_t * tw.re * s + tw.im * c,
);
}
}
let mut y: [Complex<T>; 25] = [Complex::zero(); 25];
for k1 in 0..5 {
let base = k1 * 5;
let row_input: [Complex<T>; 5] =
[t[base], t[base + 1], t[base + 2], t[base + 3], t[base + 4]];
let row_output = dft5(&row_input, sign_t);
for k2 in 0..5 {
y[k2 * 5 + k1] = row_output[k2];
}
}
y
}
#[inline]
pub fn notw_12<T: Float>(x: &mut [Complex<T>], sign: i32) {
debug_assert!(x.len() >= 12);
let sqrt3_2 = T::from_f64(0.8660254037844387);
let half = T::from_f64(0.5);
let sign_t = if sign < 0 { -T::ONE } else { T::ONE };
let w1_re = sqrt3_2;
let w1_im = half * sign_t;
let w2_re = half;
let w2_im = sqrt3_2 * sign_t;
let w3_re = T::ZERO;
let w3_im = T::ONE * sign_t;
let w4_re = -half;
let w4_im = sqrt3_2 * sign_t;
let mut t: [Complex<T>; 12] = [Complex::zero(); 12];
for j in 0..4 {
let a0 = x[j];
let a1 = x[j + 4];
let a2 = x[j + 8];
let sum = a0 + a1 + a2;
let d1 = a1 - a2;
let d2 = a0 - (a1 + a2).scale(half);
let rot_re = -sign_t * sqrt3_2 * d1.im;
let rot_im = -sign_t * sqrt3_2 * d1.re;
t[j] = sum;
t[j + 4] = Complex::new(d2.re + rot_re, d2.im - rot_im);
t[j + 8] = Complex::new(d2.re - rot_re, d2.im + rot_im);
}
let tw = t[5];
t[5] = Complex::new(tw.re * w1_re - tw.im * w1_im, tw.re * w1_im + tw.im * w1_re);
let tw = t[6];
t[6] = Complex::new(tw.re * w2_re - tw.im * w2_im, tw.re * w2_im + tw.im * w2_re);
let tw = t[7];
t[7] = Complex::new(tw.re * w3_re - tw.im * w3_im, tw.re * w3_im + tw.im * w3_re);
let tw = t[9];
t[9] = Complex::new(tw.re * w2_re - tw.im * w2_im, tw.re * w2_im + tw.im * w2_re);
let tw = t[10];
t[10] = Complex::new(tw.re * w4_re - tw.im * w4_im, tw.re * w4_im + tw.im * w4_re);
t[11] = Complex::new(-t[11].re, -t[11].im);
for k1 in 0..3 {
let base = k1 * 4;
let a0 = t[base];
let a1 = t[base + 1];
let a2 = t[base + 2];
let a3 = t[base + 3];
let s02 = a0 + a2;
let d02 = a0 - a2;
let s13 = a1 + a3;
let d13 = a1 - a3;
let rot_d13 = Complex::new(sign_t * d13.im, -sign_t * d13.re);
x[k1] = s02 + s13;
x[k1 + 3] = d02 - rot_d13;
x[k1 + 6] = s02 - s13;
x[k1 + 9] = d02 + rot_d13;
}
}
#[inline]
pub fn notw_24<T: Float>(x: &mut [Complex<T>], sign: i32) {
debug_assert!(x.len() >= 24);
let sqrt3_2 = T::from_f64(0.8660254037844387);
let half = T::from_f64(0.5);
let sign_t = if sign < 0 { -T::ONE } else { T::ONE };
let c1 = T::from_f64(0.9659258262890683);
let s1 = T::from_f64(0.25881904510252074);
let c2 = sqrt3_2;
let s2 = half;
let c3 = T::from_f64(0.7071067811865476);
let s3 = c3;
let c4 = half;
let s4 = sqrt3_2;
let c5 = s1;
let s5 = c1;
let mut t: [Complex<T>; 24] = [Complex::zero(); 24];
for j in 0..8 {
let a0 = x[j];
let a1 = x[j + 8];
let a2 = x[j + 16];
let sum = a0 + a1 + a2;
let d1 = a1 - a2;
let d2 = a0 - (a1 + a2).scale(half);
let rot_re = -sign_t * sqrt3_2 * d1.im;
let rot_im = -sign_t * sqrt3_2 * d1.re;
t[j] = sum;
t[j + 8] = Complex::new(d2.re + rot_re, d2.im - rot_im);
t[j + 16] = Complex::new(d2.re - rot_re, d2.im + rot_im);
}
let tw = t[9];
t[9] = Complex::new(
tw.re * c1 - sign_t * tw.im * s1,
sign_t * tw.re * s1 + tw.im * c1,
);
let tw = t[10];
t[10] = Complex::new(
tw.re * c2 - sign_t * tw.im * s2,
sign_t * tw.re * s2 + tw.im * c2,
);
let tw = t[11];
t[11] = Complex::new(
tw.re * c3 - sign_t * tw.im * s3,
sign_t * tw.re * s3 + tw.im * c3,
);
let tw = t[12];
t[12] = Complex::new(
tw.re * c4 - sign_t * tw.im * s4,
sign_t * tw.re * s4 + tw.im * c4,
);
let tw = t[13];
t[13] = Complex::new(
tw.re * c5 - sign_t * tw.im * s5,
sign_t * tw.re * s5 + tw.im * c5,
);
let tw = t[14];
t[14] = Complex::new(-sign_t * tw.im, sign_t * tw.re);
let tw = t[15];
t[15] = Complex::new(
-tw.re * s1 - sign_t * tw.im * c1,
sign_t * tw.re * c1 - tw.im * s1,
);
let tw = t[17];
t[17] = Complex::new(
tw.re * c2 - sign_t * tw.im * s2,
sign_t * tw.re * s2 + tw.im * c2,
);
let tw = t[18];
t[18] = Complex::new(
tw.re * c4 - sign_t * tw.im * s4,
sign_t * tw.re * s4 + tw.im * c4,
);
let tw = t[19];
t[19] = Complex::new(-sign_t * tw.im, sign_t * tw.re);
let tw = t[20];
t[20] = Complex::new(
-tw.re * c4 - sign_t * tw.im * s4,
sign_t * tw.re * s4 - tw.im * c4,
);
let tw = t[21];
t[21] = Complex::new(
-tw.re * c2 - sign_t * tw.im * s2,
sign_t * tw.re * s2 - tw.im * c2,
);
t[22] = -t[22];
let tw = t[23];
t[23] = Complex::new(
-tw.re * c2 + sign_t * tw.im * s2,
-sign_t * tw.re * s2 - tw.im * c2,
);
for k1 in 0..3 {
let base = k1 * 8;
simd::notw_8_dispatch(&mut t[base..base + 8], sign);
}
for k1 in 0..3 {
let base = k1 * 8;
for k2 in 0..8 {
x[k2 * 3 + k1] = t[base + k2];
}
}
}
#[inline]
pub fn notw_36<T: Float>(x: &mut [Complex<T>], sign: i32) {
debug_assert!(x.len() >= 36);
let sign_t = if sign < 0 { -T::ONE } else { T::ONE };
let mut t: [Complex<T>; 36] = [Complex::zero(); 36];
for j in 0..9 {
let a0 = x[j];
let a1 = x[j + 9];
let a2 = x[j + 18];
let a3 = x[j + 27];
let s02 = a0 + a2;
let d02 = a0 - a2;
let s13 = a1 + a3;
let d13 = a1 - a3;
let rot_d13 = Complex::new(sign_t * d13.im, -sign_t * d13.re);
t[j] = s02 + s13;
t[j + 9] = d02 - rot_d13;
t[j + 18] = s02 - s13;
t[j + 27] = d02 + rot_d13;
}
let c1 = T::from_f64(0.984807753012208);
let s1 = T::from_f64(0.17364817766693033);
let c2 = T::from_f64(0.9396926207859084);
let s2 = T::from_f64(0.3420201433256687);
let c3 = T::from_f64(0.8660254037844387);
let s3 = T::from_f64(0.5);
let c4 = T::from_f64(0.766044443118978);
let s4 = T::from_f64(0.6427876096865393);
let c5 = T::from_f64(0.6427876096865394);
let s5 = T::from_f64(0.766044443118978);
let c6 = T::from_f64(0.5);
let s6 = T::from_f64(0.8660254037844387);
let c7 = T::from_f64(0.3420201433256688);
let s7 = T::from_f64(0.9396926207859084);
let c8 = T::from_f64(0.17364817766693041);
let s8 = T::from_f64(0.984807753012208);
let c10 = T::from_f64(-0.17364817766693033);
let s10 = T::from_f64(0.984807753012208);
let c12 = T::from_f64(-0.5);
let s12 = T::from_f64(0.8660254037844387);
let c14 = T::from_f64(-0.766044443118978);
let s14 = T::from_f64(0.6427876096865394);
let c15 = T::from_f64(-0.8660254037844387);
let s15 = T::from_f64(0.5);
let c16 = T::from_f64(-0.9396926207859084);
let s16 = T::from_f64(0.3420201433256687);
let c21 = T::from_f64(-0.8660254037844387);
let s21 = T::from_f64(-0.5);
let c24 = T::from_f64(-0.5);
let s24 = T::from_f64(-0.8660254037844387);
#[inline]
fn apply_tw<T: Float>(tw: Complex<T>, c: T, s: T, sign_t: T) -> Complex<T> {
Complex::new(
tw.re * c - sign_t * tw.im * s,
sign_t * tw.re * s + tw.im * c,
)
}
t[10] = apply_tw(t[10], c1, s1, sign_t);
t[11] = apply_tw(t[11], c2, s2, sign_t);
t[12] = apply_tw(t[12], c3, s3, sign_t);
t[13] = apply_tw(t[13], c4, s4, sign_t);
t[14] = apply_tw(t[14], c5, s5, sign_t);
t[15] = apply_tw(t[15], c6, s6, sign_t);
t[16] = apply_tw(t[16], c7, s7, sign_t);
t[17] = apply_tw(t[17], c8, s8, sign_t);
t[19] = apply_tw(t[19], c2, s2, sign_t);
t[20] = apply_tw(t[20], c4, s4, sign_t);
t[21] = apply_tw(t[21], c6, s6, sign_t);
t[22] = apply_tw(t[22], c8, s8, sign_t);
t[23] = apply_tw(t[23], c10, s10, sign_t);
t[24] = apply_tw(t[24], c12, s12, sign_t);
t[25] = apply_tw(t[25], c14, s14, sign_t);
t[26] = apply_tw(t[26], c16, s16, sign_t);
t[28] = apply_tw(t[28], c3, s3, sign_t);
t[29] = apply_tw(t[29], c6, s6, sign_t);
let tw = t[30];
t[30] = Complex::new(-sign_t * tw.im, sign_t * tw.re);
t[31] = apply_tw(t[31], c12, s12, sign_t);
t[32] = apply_tw(t[32], c15, s15, sign_t);
t[33] = -t[33];
t[34] = apply_tw(t[34], c21, s21, sign_t);
t[35] = apply_tw(t[35], c24, s24, sign_t);
for k1 in 0..4 {
let base = k1 * 9;
let a: [Complex<T>; 9] = [
t[base],
t[base + 1],
t[base + 2],
t[base + 3],
t[base + 4],
t[base + 5],
t[base + 6],
t[base + 7],
t[base + 8],
];
let y = dft9(&a, sign_t);
for k2 in 0..9 {
x[k2 * 4 + k1] = y[k2];
}
}
}
#[inline]
pub fn notw_48<T: Float>(x: &mut [Complex<T>], sign: i32) {
debug_assert!(x.len() >= 48);
let sqrt3_2 = T::from_f64(0.8660254037844387);
let half = T::from_f64(0.5);
let sign_t = if sign < 0 { -T::ONE } else { T::ONE };
let mut t: [Complex<T>; 48] = [Complex::zero(); 48];
for j in 0..16 {
let a0 = x[j];
let a1 = x[j + 16];
let a2 = x[j + 32];
let sum = a0 + a1 + a2;
let d1 = a1 - a2;
let d2 = a0 - (a1 + a2).scale(half);
let rot_re = -sign_t * sqrt3_2 * d1.im;
let rot_im = -sign_t * sqrt3_2 * d1.re;
t[j] = sum;
t[j + 16] = Complex::new(d2.re + rot_re, d2.im - rot_im);
t[j + 32] = Complex::new(d2.re - rot_re, d2.im + rot_im);
}
let cos_48: [T; 31] = [
T::ONE,
T::from_f64(0.9914448613738104),
T::from_f64(0.9659258262890683),
T::from_f64(0.9238795325112867),
T::from_f64(0.8660254037844387),
T::from_f64(0.7933533402912352),
T::from_f64(0.7071067811865476),
T::from_f64(0.6087614290087207),
T::from_f64(0.5),
T::from_f64(0.38268343236508984),
T::from_f64(0.25881904510252074),
T::from_f64(0.1305261922200516),
T::ZERO,
T::from_f64(-0.1305261922200516),
T::from_f64(-0.25881904510252074),
T::from_f64(-0.38268343236508984),
T::from_f64(-0.5),
T::from_f64(-0.6087614290087207),
T::from_f64(-0.7071067811865476),
T::from_f64(-0.7933533402912352),
T::from_f64(-0.8660254037844387),
T::from_f64(-0.9238795325112867),
T::from_f64(-0.9659258262890683),
T::from_f64(-0.9914448613738104),
-T::ONE,
T::from_f64(-0.9914448613738104),
T::from_f64(-0.9659258262890683),
T::from_f64(-0.9238795325112867),
T::from_f64(-0.8660254037844387),
T::from_f64(-0.7933533402912352),
T::from_f64(-0.7071067811865476),
];
let sin_48: [T; 31] = [
T::ZERO,
T::from_f64(0.1305261922200516),
T::from_f64(0.25881904510252074),
T::from_f64(0.38268343236508984),
T::from_f64(0.5),
T::from_f64(0.6087614290087207),
T::from_f64(0.7071067811865476),
T::from_f64(0.7933533402912352),
T::from_f64(0.8660254037844387),
T::from_f64(0.9238795325112867),
T::from_f64(0.9659258262890683),
T::from_f64(0.9914448613738104),
T::ONE,
T::from_f64(0.9914448613738104),
T::from_f64(0.9659258262890683),
T::from_f64(0.9238795325112867),
T::from_f64(0.8660254037844387),
T::from_f64(0.7933533402912352),
T::from_f64(0.7071067811865476),
T::from_f64(0.6087614290087207),
T::from_f64(0.5),
T::from_f64(0.38268343236508984),
T::from_f64(0.25881904510252074),
T::from_f64(0.1305261922200516),
T::ZERO,
T::from_f64(-0.1305261922200516),
T::from_f64(-0.25881904510252074),
T::from_f64(-0.38268343236508984),
T::from_f64(-0.5),
T::from_f64(-0.6087614290087207),
T::from_f64(-0.7071067811865476),
];
for k2 in 1..16 {
let c = cos_48[k2];
let s = sin_48[k2];
let tw = t[16 + k2];
t[16 + k2] = Complex::new(
tw.re * c - sign_t * tw.im * s,
sign_t * tw.re * s + tw.im * c,
);
}
for k2 in 1..16 {
let c = cos_48[2 * k2];
let s = sin_48[2 * k2];
let tw = t[32 + k2];
t[32 + k2] = Complex::new(
tw.re * c - sign_t * tw.im * s,
sign_t * tw.re * s + tw.im * c,
);
}
for k1 in 0..3 {
let base = k1 * 16;
let mut row: [Complex<T>; 16] = [Complex::zero(); 16];
row.copy_from_slice(&t[base..base + 16]);
simd::notw_16_dispatch(&mut row, sign);
for k2 in 0..16 {
x[k2 * 3 + k1] = row[k2];
}
}
}