use crate::kernel::{Complex, Float};
/// Real-to-halfcomplex forward DFT of length 2 (no scaling).
///
/// Computes `Y[0] = x[0] + x[1]` (DC) and `Y[1] = x[0] - x[1]` (Nyquist).
/// Both bins are purely real, so their imaginary parts are written as zero.
///
/// # Panics
/// Debug builds assert `x.len() == 2` and `y.len() >= 2`. Any build panics on
/// out-of-bounds indexing if either slice is too short.
#[inline(always)]
pub fn r2hc_2<T: Float>(x: &[T], y: &mut [Complex<T>]) {
    debug_assert_eq!(x.len(), 2, "r2hc_2: input must have exactly 2 elements");
    debug_assert!(y.len() >= 2, "r2hc_2: output must have at least 2 elements");
    let (first, second) = (x[0], x[1]);
    y[0] = Complex::new(first + second, T::ZERO);
    y[1] = Complex::new(first - second, T::ZERO);
}
/// Real-to-halfcomplex forward DFT of length 4 (no scaling).
///
/// Writes the non-redundant bins `Y[0..=2]` of `Y[k] = Σ_j x[j]·e^{-2πi·jk/4}`
/// into `y[0..3]`. `Y[0]` and `Y[2]` are purely real.
///
/// # Panics
/// Debug builds assert `x.len() == 4` and `y.len() >= 3`. Any build panics on
/// out-of-bounds indexing if either slice is too short.
#[inline(always)]
pub fn r2hc_4<T: Float>(x: &[T], y: &mut [Complex<T>]) {
    debug_assert_eq!(x.len(), 4, "r2hc_4: input must have exactly 4 elements");
    debug_assert!(y.len() >= 3, "r2hc_4: output must have at least 3 elements");

    // Radix-2 butterflies on the stride-2 pairs (x0, x2) and (x1, x3).
    let (sum02, diff02) = (x[0] + x[2], x[0] - x[2]);
    let (sum13, diff13) = (x[1] + x[3], x[1] - x[3]);

    y[0] = Complex::new(sum02 + sum13, T::ZERO);
    // Y[1] = x0 - i·x1 - x2 + i·x3 = (x0 - x2) - i·(x1 - x3)
    y[1] = Complex::new(diff02, -diff13);
    y[2] = Complex::new(sum02 - sum13, T::ZERO);
}
/// Real-to-halfcomplex forward DFT of length 8 (no scaling).
///
/// Splits `x` into its even- and odd-indexed samples and takes a real length-4
/// DFT (`E`, `O`) of each. It then merges them with the twiddles
/// `W8^k = e^{-2πi·k/8}`: `Y[k] = E[k] + W8^k·O[k]` and `Y[4] = E[0] - O[0]`.
/// For `k = 3` it uses Hermitian symmetry: `E[3] = conj(E[1])`, `O[3] = conj(O[1])`.
/// The non-redundant bins `Y[0..=4]` are written into `y[0..5]`.
///
/// # Panics
/// Debug builds assert `x.len() == 8` and `y.len() >= 5`. Any build panics on
/// out-of-bounds indexing if either slice is too short.
#[inline(always)]
pub fn r2hc_8<T: Float>(x: &[T], y: &mut [Complex<T>]) {
    debug_assert_eq!(x.len(), 8, "r2hc_8: input must have exactly 8 elements");
    debug_assert!(y.len() >= 5, "r2hc_8: output must have at least 5 elements");

    /// Packs bins 0..=2 of a real length-4 DFT `(p0, p1, p2, p3)` from its two
    /// radix-2 butterflies `p0 ± p2` and `p1 ± p3`.
    #[inline(always)]
    fn quarter_spectrum<T: Float>(sum02: T, diff02: T, sum13: T, diff13: T) -> [Complex<T>; 3] {
        [
            Complex::new(sum02 + sum13, T::ZERO),
            Complex::new(diff02, -diff13),
            Complex::new(sum02 - sum13, T::ZERO),
        ]
    }

    // Even samples x[0], x[2], x[4], x[6].
    let [even0, even1, even2] =
        quarter_spectrum(x[0] + x[4], x[0] - x[4], x[2] + x[6], x[2] - x[6]);
    // Odd samples x[1], x[3], x[5], x[7].
    let [odd0, odd1, odd2] =
        quarter_spectrum(x[1] + x[5], x[1] - x[5], x[3] + x[7], x[3] - x[7]);

    let half_root2 = T::ONE / <T as Float>::sqrt(T::TWO);
    let twiddle1 = Complex::new(half_root2, -half_root2); // e^{-iπ/4}
    let twiddle2 = Complex::new(T::ZERO, -T::ONE); // e^{-iπ/2} = -i
    let twiddle3 = Complex::new(-half_root2, -half_root2); // e^{-3iπ/4}

    y[0] = even0 + odd0;
    y[1] = even1 + twiddle1 * odd1;
    y[2] = even2 + twiddle2 * odd2;
    y[3] = even1.conj() + twiddle3 * odd1.conj();
    y[4] = even0 - odd0;
}
/// Halfcomplex-to-real inverse DFT of length 2, **unnormalized**.
///
/// Only the real parts of `y[0]` (DC) and `y[1]` (Nyquist) are read. Their
/// imaginary parts are ignored. The output is `2×` the signal that
/// [`r2hc_2`] consumed, so divide by 2 to invert it.
///
/// # Panics
/// Debug builds assert `y.len() >= 2` and `x.len() == 2`. Any build panics on
/// out-of-bounds indexing if either slice is too short.
#[inline(always)]
pub fn hc2r_2<T: Float>(y: &[Complex<T>], x: &mut [T]) {
    debug_assert!(y.len() >= 2, "hc2r_2: input must have at least 2 elements");
    debug_assert_eq!(x.len(), 2, "hc2r_2: output must have exactly 2 elements");
    let dc = y[0].re;
    let nyquist = y[1].re;
    x[0] = dc + nyquist;
    x[1] = dc - nyquist;
}
/// Halfcomplex-to-real inverse DFT of length 4, **unnormalized**.
///
/// Takes the bins `Y[0..=2]`; the missing `Y[3] = conj(Y[1])` is implied.
/// Computes `x[j] = Σ_k Y[k]·e^{+2πi·jk/4}` with no `1/4` factor.
/// The imaginary parts of `y[0]` and `y[2]` are ignored.
///
/// # Panics
/// Debug builds assert `y.len() >= 3` and `x.len() == 4`. Any build panics on
/// out-of-bounds indexing if either slice is too short.
#[inline(always)]
pub fn hc2r_4<T: Float>(y: &[Complex<T>], x: &mut [T]) {
    debug_assert!(y.len() >= 3, "hc2r_4: input must have at least 3 elements");
    debug_assert_eq!(x.len(), 4, "hc2r_4: output must have exactly 4 elements");

    let (dc, nyquist) = (y[0].re, y[2].re);
    let sum02 = dc + nyquist;
    let diff02 = dc - nyquist;

    // Y[1] and its mirror Y[3] = conj(Y[1]) together contribute ±2·Re Y[1]
    // to the even outputs and ∓2·Im Y[1] to the odd outputs.
    let twice_re = y[1].re + y[1].re;
    let twice_im = y[1].im + y[1].im;

    x[0] = sum02 + twice_re;
    x[1] = diff02 - twice_im;
    x[2] = sum02 - twice_re;
    x[3] = diff02 + twice_im;
}
/// Halfcomplex-to-real inverse DFT of length 8, **unnormalized**.
///
/// Takes the five non-redundant bins `Y[0..=4]`; the rest follow from
/// `Y[8-k] = conj(Y[k])`. The imaginary parts of `y[0]` and `y[4]` are ignored.
///
/// The output is split radix-2 into even and odd samples:
/// - even samples `x[2m]` are the length-4 inverse of `E[k] = Y[k] + Y[k+4]`;
/// - odd samples `x[2m+1]` are the length-4 inverse of `O[k] = (Y[k] - Y[k+4])·e^{+iπk/4}`.
///
/// The result is `8×` the signal that [`r2hc_8`] consumed.
///
/// # Panics
/// Debug builds assert `y.len() >= 5` and `x.len() == 8`. Any build panics on
/// out-of-bounds indexing if either slice is too short.
#[inline(always)]
pub fn hc2r_8<T: Float>(y: &[Complex<T>], x: &mut [T]) {
    debug_assert!(y.len() >= 5, "hc2r_8: input must have at least 5 elements");
    debug_assert_eq!(x.len(), 8, "hc2r_8: output must have exactly 8 elements");

    /// Unnormalized inverse of a real length-4 signal from its Hermitian bins
    /// `z0`, `z1`, `z2` (`z3 = conj(z1)` implied). Only `z0.re` and `z2.re` are
    /// read from the edge bins.
    #[inline(always)]
    fn inverse_quarter<T: Float>(z0: Complex<T>, z1: Complex<T>, z2: Complex<T>) -> [T; 4] {
        let sum02 = z0.re + z2.re;
        let diff02 = z0.re - z2.re;
        let twice_re = z1.re + z1.re;
        let twice_im = z1.im + z1.im;
        [
            sum02 + twice_re,
            diff02 - twice_im,
            sum02 - twice_re,
            diff02 + twice_im,
        ]
    }

    let half_root2 = T::ONE / <T as Float>::sqrt(T::TWO);
    let rot45 = Complex::new(half_root2, half_root2); // e^{+iπ/4}
    let rot90 = Complex::new(T::ZERO, T::ONE); // e^{+iπ/2} = i

    let even0 = Complex::new(y[0].re + y[4].re, T::ZERO);
    let odd0 = Complex::new(y[0].re - y[4].re, T::ZERO);

    let mirror_of_3 = y[3].conj(); // Y[5]
    let even1 = y[1] + mirror_of_3;
    let odd1 = rot45 * (y[1] - mirror_of_3);

    let mirror_of_2 = y[2].conj(); // Y[6]
    let even2 = y[2] + mirror_of_2;
    let odd2 = rot90 * (y[2] - mirror_of_2);

    let evens = inverse_quarter(even0, even1, even2);
    let odds = inverse_quarter(odd0, odd1, odd2);

    // Interleave: x[2m] <- evens[m], x[2m + 1] <- odds[m].
    for (m, (&from_even, &from_odd)) in evens.iter().zip(odds.iter()).enumerate() {
        x[2 * m] = from_even;
        x[2 * m + 1] = from_odd;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance comparison: `|a - b| < tol`.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    /// Component-wise [`approx_eq`] on the real and imaginary parts.
    fn approx_eq_complex(a: Complex<f64>, b: Complex<f64>, tol: f64) -> bool {
        approx_eq(a.re, b.re, tol) && approx_eq(a.im, b.im, tol)
    }

    /// O(n²) reference forward DFT returning all `n` bins:
    /// `Y[k] = Σ_j x[j]·e^{-2πi·jk/n}`.
    fn naive_dft(x: &[f64]) -> Vec<Complex<f64>> {
        let n = x.len();
        (0..n)
            .map(|k| {
                x.iter().enumerate().fold(Complex::zero(), |acc, (j, &xj)| {
                    let angle = -2.0 * std::f64::consts::PI * (j * k) as f64 / n as f64;
                    let w = Complex::new(angle.cos(), angle.sin());
                    acc + Complex::new(xj, 0.0) * w
                })
            })
            .collect()
    }

    /// O(n²) reference unnormalized inverse DFT of a Hermitian spectrum.
    ///
    /// `y` holds bins `0..y.len()`. The remaining bins up to `n` are filled as
    /// `Y[k] = conj(Y[n - k])`. Returns `x[j] = Re Σ_k Y[k]·e^{+2πi·jk/n}`,
    /// with no `1/n` factor.
    fn naive_idft_unnorm(y: &[Complex<f64>], n: usize) -> Vec<f64> {
        let mut full = vec![Complex::zero(); n];
        full[..y.len()].copy_from_slice(y);
        for k in y.len()..n {
            full[k] = full[n - k].conj();
        }
        (0..n)
            .map(|j| {
                full.iter().enumerate().fold(0.0_f64, |acc, (k, &yk)| {
                    let angle = 2.0 * std::f64::consts::PI * (j * k) as f64 / n as f64;
                    acc + yk.re * angle.cos() - yk.im * angle.sin()
                })
            })
            .collect()
    }

    /// Asserts each bin of `got` matches the same bin of `naive_dft(x)`.
    fn assert_matches_dft(x: &[f64], got: &[Complex<f64>], tol: f64) {
        let want = naive_dft(x);
        assert!(
            got.len() <= want.len(),
            "more bins ({}) than DFT size ({})",
            got.len(),
            want.len()
        );
        for (k, (&g, &w)) in got.iter().zip(&want).enumerate() {
            assert!(
                approx_eq_complex(g, w, tol),
                "k={k}: codelet={g:?}, naive={w:?}"
            );
        }
    }

    /// Asserts hc2r output `got` matches the naive unnormalized inverse of `y`.
    fn assert_matches_idft(y: &[Complex<f64>], got: &[f64], tol: f64) {
        let want = naive_idft_unnorm(y, got.len());
        for (i, (&g, &w)) in got.iter().zip(&want).enumerate() {
            assert!(approx_eq(g, w, tol), "idx={i}: hc2r={g}, naive={w}");
        }
    }

    /// Asserts `recovered / n == original` with `n = original.len()`. The
    /// hc2r codelets do not apply the `1/n` normalization.
    fn assert_roundtrip(original: &[f64], recovered: &[f64], tol: f64) {
        assert_eq!(recovered.len(), original.len(), "length mismatch");
        let n = original.len();
        for (i, (&r, &o)) in recovered.iter().zip(original).enumerate() {
            let scaled = r / n as f64;
            assert!(
                approx_eq(scaled, o, tol),
                "idx={i}: recovered/{n}={scaled}, original={o}"
            );
        }
    }

    /// Writes arbitrary nonzero imaginary parts into the DC bin (`y[0]`) and
    /// the Nyquist bin (last element). A correct hc2r codelet must ignore them.
    fn poison_edge_imag(y: &mut [Complex<f64>]) {
        let last = y.len() - 1;
        y[0] = Complex::new(y[0].re, 3.0);
        y[last] = Complex::new(y[last].re, -1.25);
    }

    // ---------------------------------------------------------------- size 2

    #[test]
    fn test_r2hc_2_dc_only() {
        let x = [3.0_f64, 3.0];
        let mut y = [Complex::<f64>::zero(); 2];
        r2hc_2(&x, &mut y);
        assert!(approx_eq(y[0].re, 6.0, 1e-14), "Y[0].re={}", y[0].re);
        assert!(approx_eq(y[0].im, 0.0, 1e-14));
        assert!(approx_eq(y[1].re, 0.0, 1e-14), "Y[1].re={}", y[1].re);
        assert!(approx_eq(y[1].im, 0.0, 1e-14));
    }

    #[test]
    fn test_r2hc_2_matches_dft() {
        let x = [1.0_f64, 3.0];
        let mut y = [Complex::<f64>::zero(); 2];
        r2hc_2(&x, &mut y);
        assert_matches_dft(&x, &y, 1e-12);
    }

    #[test]
    fn test_hc2r_2_roundtrip() {
        let original = [1.0_f64, 3.0];
        let mut y = [Complex::<f64>::zero(); 2];
        r2hc_2(&original, &mut y);
        let mut recovered = [0.0_f64; 2];
        hc2r_2(&y, &mut recovered);
        assert_roundtrip(&original, &recovered, 1e-14);
    }

    #[test]
    fn test_hc2r_2_against_naive_idft() {
        let original = [2.5_f64, -1.0];
        let mut y = [Complex::<f64>::zero(); 2];
        r2hc_2(&original, &mut y);
        let mut hc2r_out = [0.0_f64; 2];
        hc2r_2(&y, &mut hc2r_out);
        assert_matches_idft(&y, &hc2r_out, 1e-12);
    }

    #[test]
    fn test_hc2r_2_ignores_edge_imag() {
        let mut y = [Complex::<f64>::zero(); 2];
        r2hc_2(&[2.5_f64, -1.0], &mut y);
        let mut clean = [0.0_f64; 2];
        hc2r_2(&y, &mut clean);
        poison_edge_imag(&mut y);
        let mut poisoned = [0.0_f64; 2];
        hc2r_2(&y, &mut poisoned);
        assert_eq!(clean, poisoned);
    }

    // ---------------------------------------------------------------- size 4

    #[test]
    fn test_r2hc_4_matches_dft() {
        let x = [1.0_f64, 2.0, 3.0, 4.0];
        let mut y = [Complex::<f64>::zero(); 3];
        r2hc_4(&x, &mut y);
        assert!(approx_eq(y[0].re, 10.0, 1e-12), "Y[0].re={}", y[0].re);
        assert!(approx_eq(y[0].im, 0.0, 1e-14));
        assert!(approx_eq(y[1].re, -2.0, 1e-12), "Y[1].re={}", y[1].re);
        assert!(approx_eq(y[1].im, 2.0, 1e-12), "Y[1].im={}", y[1].im);
        assert!(approx_eq(y[2].re, -2.0, 1e-12), "Y[2].re={}", y[2].re);
        assert!(approx_eq(y[2].im, 0.0, 1e-14));
        assert_matches_dft(&x, &y, 1e-12);
    }

    #[test]
    fn test_r2hc_4_dc_signal() {
        let x = [5.0_f64; 4];
        let mut y = [Complex::<f64>::zero(); 3];
        r2hc_4(&x, &mut y);
        assert!(
            approx_eq_complex(y[0], Complex::new(20.0, 0.0), 1e-14),
            "Y[0]={:?}",
            y[0]
        );
        for (k, bin) in y.iter().enumerate().skip(1) {
            assert!(
                approx_eq_complex(*bin, Complex::zero(), 1e-14),
                "Y[{k}]={bin:?}"
            );
        }
    }

    #[test]
    fn test_hc2r_4_roundtrip() {
        let original = [1.0_f64, 2.0, 3.0, 4.0];
        let mut y = [Complex::<f64>::zero(); 3];
        r2hc_4(&original, &mut y);
        let mut recovered = [0.0_f64; 4];
        hc2r_4(&y, &mut recovered);
        assert_roundtrip(&original, &recovered, 1e-12);
    }

    #[test]
    fn test_hc2r_4_against_naive_idft() {
        let original = [7.0_f64, -2.0, 3.5, 1.0];
        let mut y = [Complex::<f64>::zero(); 3];
        r2hc_4(&original, &mut y);
        let mut hc2r_out = [0.0_f64; 4];
        hc2r_4(&y, &mut hc2r_out);
        assert_matches_idft(&y, &hc2r_out, 1e-11);
    }

    #[test]
    fn test_hc2r_4_ignores_edge_imag() {
        let mut y = [Complex::<f64>::zero(); 3];
        r2hc_4(&[7.0_f64, -2.0, 3.5, 1.0], &mut y);
        let mut clean = [0.0_f64; 4];
        hc2r_4(&y, &mut clean);
        poison_edge_imag(&mut y);
        let mut poisoned = [0.0_f64; 4];
        hc2r_4(&y, &mut poisoned);
        assert_eq!(clean, poisoned);
    }

    // ---------------------------------------------------------------- size 8

    #[test]
    fn test_r2hc_8_dc_signal() {
        let x = [1.0_f64; 8];
        let mut y = [Complex::<f64>::zero(); 5];
        r2hc_8(&x, &mut y);
        assert!(approx_eq(y[0].re, 8.0, 1e-14), "Y[0].re={}", y[0].re);
        assert!(approx_eq(y[0].im, 0.0, 1e-14));
        for (k, bin) in y.iter().enumerate().skip(1) {
            assert!(
                approx_eq_complex(*bin, Complex::zero(), 1e-12),
                "Y[{k}]={bin:?}"
            );
        }
    }

    #[test]
    fn test_r2hc_8_impulse_is_flat() {
        let mut x = [0.0_f64; 8];
        x[0] = 1.0;
        let mut y = [Complex::<f64>::zero(); 5];
        r2hc_8(&x, &mut y);
        for (k, bin) in y.iter().enumerate() {
            assert!(
                approx_eq_complex(*bin, Complex::new(1.0, 0.0), 1e-12),
                "Y[{k}]={bin:?}"
            );
        }
    }

    #[test]
    fn test_r2hc_8_alternating_is_pure_nyquist() {
        let x = [1.0_f64, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0];
        let mut y = [Complex::<f64>::zero(); 5];
        r2hc_8(&x, &mut y);
        for (k, bin) in y.iter().enumerate().take(4) {
            assert!(
                approx_eq_complex(*bin, Complex::zero(), 1e-12),
                "Y[{k}]={bin:?}"
            );
        }
        assert!(
            approx_eq_complex(y[4], Complex::new(8.0, 0.0), 1e-12),
            "Y[4]={:?}",
            y[4]
        );
    }

    #[test]
    fn test_r2hc_8_matches_dft() {
        let x: Vec<f64> = (0..8).map(|i| i as f64).collect();
        let mut y = [Complex::<f64>::zero(); 5];
        r2hc_8(&x, &mut y);
        assert_matches_dft(&x, &y, 1e-11);
    }

    #[test]
    fn test_r2hc_8_matches_dft_varied_input() {
        let x = [1.5_f64, -2.3, 0.7, 4.1, -1.0, 3.3, 2.2, -0.5];
        let mut y = [Complex::<f64>::zero(); 5];
        r2hc_8(&x, &mut y);
        assert_matches_dft(&x, &y, 1e-11);
    }

    #[test]
    fn test_r2hc_8_dc_and_nyquist_purely_real() {
        let x = [2.0_f64, -1.0, 0.5, 3.0, -2.5, 1.5, 0.0, -0.5];
        let mut y = [Complex::<f64>::zero(); 5];
        r2hc_8(&x, &mut y);
        assert!(
            approx_eq(y[0].im, 0.0, 1e-13),
            "DC should be real, im={}",
            y[0].im
        );
        assert!(
            approx_eq(y[4].im, 0.0, 1e-13),
            "Nyquist should be real, im={}",
            y[4].im
        );
    }

    #[test]
    fn test_hc2r_8_roundtrip() {
        let original = [1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut y = [Complex::<f64>::zero(); 5];
        r2hc_8(&original, &mut y);
        let mut recovered = [0.0_f64; 8];
        hc2r_8(&y, &mut recovered);
        assert_roundtrip(&original, &recovered, 1e-11);
    }

    #[test]
    fn test_hc2r_8_roundtrip_varied() {
        let original = [1.5_f64, -2.3, 0.7, 4.1, -1.0, 3.3, 2.2, -0.5];
        let mut y = [Complex::<f64>::zero(); 5];
        r2hc_8(&original, &mut y);
        let mut recovered = [0.0_f64; 8];
        hc2r_8(&y, &mut recovered);
        assert_roundtrip(&original, &recovered, 1e-11);
    }

    #[test]
    fn test_hc2r_8_against_naive_idft() {
        let original: Vec<f64> = (0..8).map(|i| i as f64).collect();
        let mut y = [Complex::<f64>::zero(); 5];
        r2hc_8(&original, &mut y);
        let mut hc2r_out = [0.0_f64; 8];
        hc2r_8(&y, &mut hc2r_out);
        assert_matches_idft(&y, &hc2r_out, 1e-10);
    }

    #[test]
    fn test_hc2r_8_ignores_edge_imag() {
        let mut y = [Complex::<f64>::zero(); 5];
        r2hc_8(&[1.5_f64, -2.3, 0.7, 4.1, -1.0, 3.3, 2.2, -0.5], &mut y);
        let mut clean = [0.0_f64; 8];
        hc2r_8(&y, &mut clean);
        poison_edge_imag(&mut y);
        let mut poisoned = [0.0_f64; 8];
        hc2r_8(&y, &mut poisoned);
        assert_eq!(clean, poisoned);
    }

    // ---------------------------------------------------------------- f32

    #[test]
    fn test_r2hc_2_f32() {
        let x = [1.0_f32, 3.0];
        let mut y = [Complex::<f32>::zero(); 2];
        r2hc_2(&x, &mut y);
        assert!((y[0].re - 4.0_f32).abs() < 1e-6, "Y[0].re={}", y[0].re);
        assert!((y[1].re - (-2.0_f32)).abs() < 1e-6, "Y[1].re={}", y[1].re);
    }

    #[test]
    fn test_r2hc_4_f32() {
        let x = [1.0_f32, 2.0, 3.0, 4.0];
        let mut y = [Complex::<f32>::zero(); 3];
        r2hc_4(&x, &mut y);
        assert!((y[0].re - 10.0_f32).abs() < 1e-5, "Y[0].re={}", y[0].re);
        assert!((y[1].re - (-2.0_f32)).abs() < 1e-5, "Y[1].re={}", y[1].re);
        assert!((y[1].im - 2.0_f32).abs() < 1e-5, "Y[1].im={}", y[1].im);
    }

    #[test]
    fn test_r2hc_8_f32() {
        let x: Vec<f32> = (0..8).map(|i| i as f32).collect();
        let mut y = [Complex::<f32>::zero(); 5];
        r2hc_8(&x, &mut y);
        assert!((y[0].re - 28.0_f32).abs() < 1e-4, "Y[0].re={}", y[0].re);
        assert!(y[0].im.abs() < 1e-5, "DC should be real, im={}", y[0].im);
    }

    #[test]
    fn test_hc2r_8_roundtrip_f32() {
        let original = [1.5_f32, -2.3, 0.7, 4.1, -1.0, 3.3, 2.2, -0.5];
        let mut y = [Complex::<f32>::zero(); 5];
        r2hc_8(&original, &mut y);
        let mut recovered = [0.0_f32; 8];
        hc2r_8(&y, &mut recovered);
        for (i, (&r, &o)) in recovered.iter().zip(&original).enumerate() {
            let scaled = r / 8.0;
            assert!(
                (scaled - o).abs() < 1e-5,
                "idx={i}: recovered/8={scaled}, original={o}"
            );
        }
    }
}