#![cfg(test)]
use crate::kernel::Complex;
use oxifft_codegen::{gen_notw_codelet, gen_twiddle_codelet};
// Instantiate the generated codelets under test.
// Judging by the call sites in this file, `gen_notw_codelet!(N)` defines
// `codelet_notw_N(&mut data, sign)` (no twiddles), and `gen_twiddle_codelet!(N)` defines
// `codelet_twiddle_N`, whose twiddle arguments differ by size: one value for N=2, three
// values plus `sign` for N=4, and a slice of N-1 values plus `sign` for N=8 and N=16.
// NOTE(review): exact signatures come from `oxifft_codegen` — confirm there.
gen_notw_codelet!(2);
gen_notw_codelet!(4);
gen_notw_codelet!(8);
gen_notw_codelet!(16);
gen_notw_codelet!(32);
gen_notw_codelet!(64);
gen_twiddle_codelet!(2);
gen_twiddle_codelet!(4);
gen_twiddle_codelet!(8);
gen_twiddle_codelet!(16);
/// Reference O(n²) discrete Fourier transform used as ground truth for the codelets.
///
/// Computes `X[k] = Σ_j x[j] · exp(sign · 2πi · k·j / n)` for every `k` in `0..n`.
/// `sign` is the sign of the exponent (the tests use `-1` for forward and `1` for
/// backward). The result is not normalised. An empty input yields an empty vector.
fn naive_dft(x: &[Complex<f64>], sign: i32) -> Vec<Complex<f64>> {
    let n = x.len();
    let sign = f64::from(sign);
    (0..n)
        .map(|k| {
            x.iter()
                .enumerate()
                .fold(Complex::new(0.0_f64, 0.0), |acc, (j, &xj)| {
                    // exp(i·θ) is periodic in k·j with period n. Reducing the product
                    // modulo n keeps |θ| below 2π, where cos/sin are most accurate,
                    // instead of feeding angles of hundreds of radians for larger n.
                    let m = (k * j) % n;
                    let angle = sign * 2.0 * core::f64::consts::PI * m as f64 / n as f64;
                    acc + xj * Complex::new(angle.cos(), angle.sin())
                })
        })
        .collect()
}
/// Returns `true` when the real parts and the imaginary parts of `a` and `b` each
/// differ by strictly less than `eps`. This is an absolute, per-component tolerance,
/// and any NaN component makes the comparison fail.
fn approx_eq(a: Complex<f64>, b: Complex<f64>, eps: f64) -> bool {
    let close = |lhs: f64, rhs: f64| (lhs - rhs).abs() < eps;
    close(a.re, b.re) && close(a.im, b.im)
}
/// A unit impulse at index 0 transforms to a flat spectrum: both bins equal 1.
#[test]
fn test_codelet_notw_2_impulse() {
    let zero = Complex::new(0.0, 0.0);
    let mut x = [Complex::<f64>::new(1.0, 0.0), zero];
    codelet_notw_2(&mut x, -1);
    let flat = Complex::new(1.0, 0.0);
    assert!(approx_eq(x[0], flat, 1e-12));
    assert!(approx_eq(x[1], flat, 1e-12));
}
/// A constant input puts all of its energy in bin 0: `[1, 1] -> [2, 0]`.
#[test]
fn test_codelet_notw_2_dc() {
    let one = Complex::<f64>::new(1.0, 0.0);
    let mut x = [one, one];
    codelet_notw_2(&mut x, -1);
    let (dc, rest) = (x[0], x[1]);
    assert!(approx_eq(dc, Complex::new(2.0, 0.0), 1e-12));
    assert!(approx_eq(rest, Complex::new(0.0, 0.0), 1e-12));
}
/// The size-2 codelet agrees with the reference DFT on a ramp-shaped input.
#[test]
fn test_codelet_notw_2_matches_naive() {
    let input: Vec<Complex<f64>> = (0..2_i32)
        .map(f64::from)
        .map(|r| Complex::new(r * 1.5 + 0.3, r * 0.7))
        .collect();
    let mut actual = input.clone();
    let expected = naive_dft(&input, -1);
    codelet_notw_2(&mut actual, -1);
    for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
        assert!(
            approx_eq(*a, *e, 1e-12),
            "codelet_notw_2 index {i}: {a:?} != {e:?}"
        );
    }
}
/// A forward transform followed by a backward transform, scaled by 1/2, returns the input.
#[test]
fn test_codelet_notw_2_roundtrip() {
    let original: Vec<Complex<f64>> = (0..2_i32)
        .map(f64::from)
        .map(|r| Complex::new(r.sin(), r.cos()))
        .collect();
    let mut data = original.clone();
    for sign in [-1, 1] {
        codelet_notw_2(&mut data, sign);
    }
    data.iter_mut()
        .for_each(|x| *x = Complex::new(x.re / 2.0, x.im / 2.0));
    for (i, (a, e)) in data.iter().zip(&original).enumerate() {
        assert!(
            approx_eq(*a, *e, 1e-12),
            "notw_2 roundtrip index {i}: {a:?} != {e:?}"
        );
    }
}
/// A unit impulse at index 0 yields a flat size-4 spectrum of ones.
#[test]
fn test_codelet_notw_4_impulse() {
    let mut x: Vec<Complex<f64>> = vec![Complex::new(0.0, 0.0); 4];
    x[0] = Complex::new(1.0, 0.0);
    codelet_notw_4(&mut x, -1);
    for (i, v) in x.into_iter().enumerate() {
        assert!(
            approx_eq(v, Complex::new(1.0, 0.0), 1e-12),
            "notw_4 impulse index {i}: {v:?}"
        );
    }
}
/// A constant size-4 input lands entirely in bin 0 (value 4); every other bin must vanish.
#[test]
fn test_codelet_notw_4_dc() {
    let mut x: Vec<Complex<f64>> = vec![Complex::new(1.0, 0.0); 4];
    codelet_notw_4(&mut x, -1);
    assert!(
        approx_eq(x[0], Complex::new(4.0, 0.0), 1e-12),
        "notw_4 DC index 0: {:?}",
        x[0]
    );
    // Iterate the non-DC bins directly instead of indexing (clippy::needless_range_loop).
    for (i, v) in x.iter().enumerate().skip(1) {
        assert!(
            approx_eq(*v, Complex::new(0.0, 0.0), 1e-12),
            "notw_4 DC index {i}: {v:?}"
        );
    }
}
/// The size-4 codelet agrees with the reference DFT on a linear ramp input.
#[test]
fn test_codelet_notw_4_matches_naive() {
    let input: Vec<Complex<f64>> = (0..4_i32)
        .map(f64::from)
        .map(|r| Complex::new(r * 1.3, r * 0.9))
        .collect();
    let mut actual = input.clone();
    let expected = naive_dft(&input, -1);
    codelet_notw_4(&mut actual, -1);
    for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
        assert!(approx_eq(*a, *e, 1e-11), "notw_4 index {i}: {a:?} != {e:?}");
    }
}
/// Forward then backward size-4 transforms, scaled by 1/4, reproduce the input.
#[test]
fn test_codelet_notw_4_roundtrip() {
    let original: Vec<Complex<f64>> = (0..4_i32)
        .map(f64::from)
        .map(|r| Complex::new(r.sin(), r.cos()))
        .collect();
    let mut data = original.clone();
    for sign in [-1, 1] {
        codelet_notw_4(&mut data, sign);
    }
    data.iter_mut()
        .for_each(|x| *x = Complex::new(x.re / 4.0, x.im / 4.0));
    for (i, (a, e)) in data.iter().zip(&original).enumerate() {
        assert!(
            approx_eq(*a, *e, 1e-12),
            "notw_4 roundtrip index {i}: {a:?} != {e:?}"
        );
    }
}
/// A unit impulse at index 0 yields a flat size-8 spectrum of ones.
#[test]
fn test_codelet_notw_8_impulse() {
    let mut x: Vec<Complex<f64>> = vec![Complex::new(0.0, 0.0); 8];
    x[0] = Complex::new(1.0, 0.0);
    codelet_notw_8(&mut x, -1);
    for (i, v) in x.into_iter().enumerate() {
        assert!(
            approx_eq(v, Complex::new(1.0, 0.0), 1e-12),
            "notw_8 impulse index {i}: {v:?}"
        );
    }
}
/// A constant size-8 input lands entirely in bin 0 (value 8); every other bin must vanish.
#[test]
fn test_codelet_notw_8_dc() {
    let mut x: Vec<Complex<f64>> = vec![Complex::new(1.0, 0.0); 8];
    codelet_notw_8(&mut x, -1);
    assert!(
        approx_eq(x[0], Complex::new(8.0, 0.0), 1e-12),
        "notw_8 DC index 0: {:?}",
        x[0]
    );
    // Iterate the non-DC bins directly instead of indexing (clippy::needless_range_loop).
    for (i, v) in x.iter().enumerate().skip(1) {
        assert!(
            approx_eq(*v, Complex::new(0.0, 0.0), 1e-11),
            "notw_8 DC index {i}: {v:?}"
        );
    }
}
/// The size-8 codelet agrees with the reference DFT on a sinusoidal input.
#[test]
fn test_codelet_notw_8_matches_naive() {
    let input: Vec<Complex<f64>> = (0..8_i32)
        .map(f64::from)
        .map(|r| Complex::new(r.sin() * 2.0 + 0.5, r.cos()))
        .collect();
    let mut actual = input.clone();
    let expected = naive_dft(&input, -1);
    codelet_notw_8(&mut actual, -1);
    for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
        assert!(approx_eq(*a, *e, 1e-10), "notw_8 index {i}: {a:?} != {e:?}");
    }
}
/// Forward then backward size-8 transforms, scaled by 1/8, reproduce the input.
#[test]
fn test_codelet_notw_8_roundtrip() {
    let original: Vec<Complex<f64>> = (0..8_i32)
        .map(f64::from)
        .map(|r| Complex::new(r.sin(), r.cos()))
        .collect();
    let mut data = original.clone();
    for sign in [-1, 1] {
        codelet_notw_8(&mut data, sign);
    }
    data.iter_mut()
        .for_each(|x| *x = Complex::new(x.re / 8.0, x.im / 8.0));
    for (i, (a, e)) in data.iter().zip(&original).enumerate() {
        assert!(
            approx_eq(*a, *e, 1e-11),
            "notw_8 roundtrip index {i}: {a:?} != {e:?}"
        );
    }
}
/// A unit impulse at index 0 yields a flat size-16 spectrum of ones.
#[test]
fn test_codelet_notw_16_impulse() {
    let mut x: Vec<Complex<f64>> = vec![Complex::new(0.0, 0.0); 16];
    x[0] = Complex::new(1.0, 0.0);
    codelet_notw_16(&mut x, -1);
    for (i, v) in x.into_iter().enumerate() {
        assert!(
            approx_eq(v, Complex::new(1.0, 0.0), 1e-11),
            "notw_16 impulse index {i}: {v:?}"
        );
    }
}
/// A constant size-16 input lands entirely in bin 0 (value 16); every other bin must vanish.
#[test]
fn test_codelet_notw_16_dc() {
    let mut x: Vec<Complex<f64>> = vec![Complex::new(1.0, 0.0); 16];
    codelet_notw_16(&mut x, -1);
    assert!(
        approx_eq(x[0], Complex::new(16.0, 0.0), 1e-11),
        "notw_16 DC index 0: {:?}",
        x[0]
    );
    // Iterate the non-DC bins directly instead of indexing (clippy::needless_range_loop).
    for (i, v) in x.iter().enumerate().skip(1) {
        assert!(
            approx_eq(*v, Complex::new(0.0, 0.0), 1e-10),
            "notw_16 DC index {i}: {v:?}"
        );
    }
}
/// The size-16 codelet agrees with the reference DFT on an offset ramp input.
#[test]
fn test_codelet_notw_16_matches_naive() {
    let input: Vec<Complex<f64>> = (0..16_i32)
        .map(f64::from)
        .map(|r| Complex::new(r * 0.5, r * 0.3 - 1.0))
        .collect();
    let mut actual = input.clone();
    let expected = naive_dft(&input, -1);
    codelet_notw_16(&mut actual, -1);
    for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
        assert!(approx_eq(*a, *e, 1e-9), "notw_16 index {i}: {a:?} != {e:?}");
    }
}
/// Forward then backward size-16 transforms, scaled by 1/16, reproduce the input.
#[test]
fn test_codelet_notw_16_roundtrip() {
    let original: Vec<Complex<f64>> = (0..16_i32)
        .map(f64::from)
        .map(|r| Complex::new(r.sin(), r.cos()))
        .collect();
    let mut data = original.clone();
    for sign in [-1, 1] {
        codelet_notw_16(&mut data, sign);
    }
    data.iter_mut()
        .for_each(|x| *x = Complex::new(x.re / 16.0, x.im / 16.0));
    for (i, (a, e)) in data.iter().zip(&original).enumerate() {
        assert!(
            approx_eq(*a, *e, 1e-10),
            "notw_16 roundtrip index {i}: {a:?} != {e:?}"
        );
    }
}
/// A unit impulse at index 0 yields a flat size-32 spectrum of ones.
#[test]
fn test_codelet_notw_32_impulse() {
    let mut x: Vec<Complex<f64>> = vec![Complex::new(0.0, 0.0); 32];
    x[0] = Complex::new(1.0, 0.0);
    codelet_notw_32(&mut x, -1);
    for (i, v) in x.into_iter().enumerate() {
        assert!(
            approx_eq(v, Complex::new(1.0, 0.0), 1e-10),
            "notw_32 impulse index {i}: {v:?}"
        );
    }
}
/// A constant size-32 input lands entirely in bin 0 (value 32); every other bin must vanish.
#[test]
fn test_codelet_notw_32_dc() {
    let mut x: Vec<Complex<f64>> = vec![Complex::new(1.0, 0.0); 32];
    codelet_notw_32(&mut x, -1);
    assert!(
        approx_eq(x[0], Complex::new(32.0, 0.0), 1e-10),
        "notw_32 DC index 0: {:?}",
        x[0]
    );
    // Iterate the non-DC bins directly instead of indexing (clippy::needless_range_loop).
    for (i, v) in x.iter().enumerate().skip(1) {
        assert!(
            approx_eq(*v, Complex::new(0.0, 0.0), 1e-9),
            "notw_32 DC index {i}: {v:?}"
        );
    }
}
/// The size-32 codelet agrees with the reference DFT on a sinusoidal input.
#[test]
fn test_codelet_notw_32_matches_naive() {
    let input: Vec<Complex<f64>> = (0..32_i32)
        .map(f64::from)
        .map(|r| Complex::new(r.sin(), r.cos() * 0.5))
        .collect();
    let mut actual = input.clone();
    let expected = naive_dft(&input, -1);
    codelet_notw_32(&mut actual, -1);
    for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
        assert!(approx_eq(*a, *e, 1e-8), "notw_32 index {i}: {a:?} != {e:?}");
    }
}
/// A unit impulse at index 0 yields a flat size-64 spectrum of ones.
#[test]
fn test_codelet_notw_64_impulse() {
    let mut x: Vec<Complex<f64>> = vec![Complex::new(0.0, 0.0); 64];
    x[0] = Complex::new(1.0, 0.0);
    codelet_notw_64(&mut x, -1);
    for (i, v) in x.into_iter().enumerate() {
        assert!(
            approx_eq(v, Complex::new(1.0, 0.0), 1e-9),
            "notw_64 impulse index {i}: {v:?}"
        );
    }
}
/// A constant size-64 input lands entirely in bin 0 (value 64); every other bin must vanish.
#[test]
fn test_codelet_notw_64_dc() {
    let mut x: Vec<Complex<f64>> = vec![Complex::new(1.0, 0.0); 64];
    codelet_notw_64(&mut x, -1);
    assert!(
        approx_eq(x[0], Complex::new(64.0, 0.0), 1e-9),
        "notw_64 DC index 0: {:?}",
        x[0]
    );
    // Iterate the non-DC bins directly instead of indexing (clippy::needless_range_loop).
    for (i, v) in x.iter().enumerate().skip(1) {
        assert!(
            approx_eq(*v, Complex::new(0.0, 0.0), 1e-8),
            "notw_64 DC index {i}: {v:?}"
        );
    }
}
/// The size-64 codelet agrees with the reference DFT on an offset sinusoidal input.
#[test]
fn test_codelet_notw_64_matches_naive() {
    let input: Vec<Complex<f64>> = (0..64_i32)
        .map(f64::from)
        .map(|r| Complex::new(r.sin() * 0.3 + 1.0, r.cos() * 0.5))
        .collect();
    let mut actual = input.clone();
    let expected = naive_dft(&input, -1);
    codelet_notw_64(&mut actual, -1);
    for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
        assert!(approx_eq(*a, *e, 1e-7), "notw_64 index {i}: {a:?} != {e:?}");
    }
}
/// With a unit twiddle the size-2 twiddle codelet must match the plain size-2 codelet.
#[test]
fn test_codelet_twiddle_2_identity() {
    let input = [Complex::new(1.0_f64, 2.0), Complex::new(3.0, 4.0)];
    let (mut tw, mut notw) = (input, input);
    codelet_notw_2(&mut notw, -1);
    let unit = Complex::new(1.0, 0.0);
    codelet_twiddle_2(&mut tw, unit);
    assert!(approx_eq(tw[0], notw[0], 1e-12));
    assert!(approx_eq(tw[1], notw[1], 1e-12));
}
/// Checks the size-2 twiddle butterfly directly: `[a, b]` becomes `[a + b·t, a - b·t]`.
#[test]
fn test_codelet_twiddle_2_explicit() {
    let (a, b) = (Complex::new(1.0_f64, 2.0), Complex::new(3.0, 4.0));
    // t = i (a quarter turn), so b·t mixes the real and imaginary parts.
    let t = Complex::new(0.0_f64, 1.0);
    let mut x = [a, b];
    codelet_twiddle_2(&mut x, t);
    let bt = b * t;
    let (sum, diff) = (a + bt, a - bt);
    assert!(approx_eq(x[0], sum, 1e-12));
    assert!(approx_eq(x[1], diff, 1e-12));
}
/// With all-unit twiddles the size-4 twiddle codelet must match the plain size-4 codelet.
#[test]
fn test_codelet_twiddle_4_identity() {
    let input: Vec<Complex<f64>> = (0..4_i32)
        .map(f64::from)
        .map(|r| Complex::new(r * 1.7, r * 0.5 + 0.2))
        .collect();
    let mut notw = input.clone();
    let mut tw = input;
    codelet_notw_4(&mut notw, -1);
    let unit = Complex::new(1.0, 0.0);
    codelet_twiddle_4(&mut tw, unit, unit, unit, -1);
    for (i, (t, n)) in tw.iter().zip(&notw).enumerate() {
        assert!(
            approx_eq(*t, *n, 1e-12),
            "twiddle_4 identity index {i}: {t:?} != {n:?}"
        );
    }
}
/// The size-4 twiddle codelet equals a reference DFT of the input after multiplying
/// elements 1..4 by `tw1`, `tw2`, `tw3` (element 0 is left as is).
#[test]
fn test_codelet_twiddle_4_vs_naive() {
    // The twiddles are exp(-2πi·k/4) for k = 1, 2, 3.
    let tw1 = Complex::new(0.0, -1.0_f64); // -i
    let tw2 = Complex::new(-1.0, 0.0_f64); // -1
    let tw3 = Complex::new(0.0, 1.0_f64); // i
    let input: Vec<Complex<f64>> = (0..4_i32)
        .map(f64::from)
        .map(|r| Complex::new(r + 0.5, 1.0 - r * 0.3))
        .collect();
    let tweaked: Vec<Complex<f64>> =
        vec![input[0], input[1] * tw1, input[2] * tw2, input[3] * tw3];
    let expected = naive_dft(&tweaked, -1);
    let mut actual = input;
    codelet_twiddle_4(&mut actual, tw1, tw2, tw3, -1);
    for (i, (a, e)) in actual.iter().zip(&expected).enumerate() {
        assert!(
            approx_eq(*a, *e, 1e-11),
            "twiddle_4 index {i}: {a:?} != {e:?}"
        );
    }
}
/// With seven unit twiddles the size-8 twiddle codelet must match the plain size-8 codelet.
#[test]
fn test_codelet_twiddle_8_identity() {
    let input: Vec<Complex<f64>> = (0..8_i32)
        .map(f64::from)
        .map(|r| Complex::new(r.sin() + 0.1, r.cos()))
        .collect();
    let mut notw = input.clone();
    let mut tw = input;
    codelet_notw_8(&mut notw, -1);
    let identity = [Complex::new(1.0_f64, 0.0); 7];
    codelet_twiddle_8(&mut tw, &identity, -1);
    for (i, (t, n)) in tw.iter().zip(&notw).enumerate() {
        assert!(
            approx_eq(*t, *n, 1e-11),
            "twiddle_8 identity index {i}: {t:?} != {n:?}"
        );
    }
}
/// The size-8 twiddle codelet equals a reference DFT of the input in which element
/// `k + 1` has been multiplied by `twiddles[k] = exp(-2πi·(k+1)/8)`.
#[test]
fn test_codelet_twiddle_8_vs_naive() {
    let twiddles: [Complex<f64>; 7] = {
        let mut arr = [Complex::new(0.0_f64, 0.0); 7];
        for (k, item) in arr.iter_mut().enumerate() {
            let angle = -2.0 * core::f64::consts::PI * ((k + 1) as f64) / 8.0;
            *item = Complex::new(angle.cos(), angle.sin());
        }
        arr
    };
    let input: Vec<Complex<f64>> = (0..8_i32)
        .map(f64::from)
        .map(|r| Complex::new(r * 0.7, r * 0.3 + 0.5))
        .collect();
    // Element 0 is never twiddled; elements 1..8 are paired with twiddles[0..7].
    // Iterating the slices avoids manual index arithmetic (clippy::needless_range_loop).
    let mut tweaked = input.clone();
    for (slot, &w) in tweaked[1..].iter_mut().zip(twiddles.iter()) {
        *slot = *slot * w;
    }
    let expected = naive_dft(&tweaked, -1);
    let mut actual = input;
    codelet_twiddle_8(&mut actual, &twiddles, -1);
    for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
        assert!(
            approx_eq(*a, *e, 1e-10),
            "twiddle_8 index {i}: {a:?} != {e:?}"
        );
    }
}
/// A unit impulse through the size-16 twiddle codelet with unit twiddles yields all ones.
#[test]
fn test_codelet_twiddle_16_impulse_identity() {
    let mut x: Vec<Complex<f64>> = vec![Complex::new(0.0, 0.0); 16];
    x[0] = Complex::new(1.0, 0.0);
    let identity = [Complex::new(1.0_f64, 0.0); 15];
    codelet_twiddle_16(&mut x, &identity, -1);
    for (i, v) in x.into_iter().enumerate() {
        assert!(
            approx_eq(v, Complex::new(1.0, 0.0), 1e-11),
            "twiddle_16 impulse+identity index {i}: {v:?}"
        );
    }
}
/// A constant input through the size-16 twiddle codelet with unit twiddles lands
/// entirely in bin 0 (value 16); every other bin must vanish.
#[test]
fn test_codelet_twiddle_16_dc_identity() {
    let mut x: Vec<Complex<f64>> = vec![Complex::new(1.0, 0.0); 16];
    let identity = [Complex::new(1.0_f64, 0.0); 15];
    codelet_twiddle_16(&mut x, &identity, -1);
    assert!(
        approx_eq(x[0], Complex::new(16.0, 0.0), 1e-11),
        "DC[0] expected 16, got {:?}",
        x[0]
    );
    // Iterate the non-DC bins directly instead of indexing (clippy::needless_range_loop).
    for (i, v) in x.iter().enumerate().skip(1) {
        assert!(
            approx_eq(*v, Complex::new(0.0, 0.0), 1e-10),
            "twiddle_16 DC index {i}: {v:?}"
        );
    }
}
/// With fifteen unit twiddles the size-16 twiddle codelet must match the plain
/// size-16 codelet on a non-trivial input.
#[test]
fn test_codelet_twiddle_16_identity_matches_notw_16() {
    let input: Vec<Complex<f64>> = (0..16_i32)
        .map(f64::from)
        .map(|r| Complex::new(r.sin() + 0.5, r.cos() * 0.8))
        .collect();
    let mut notw = input.clone();
    let mut tw = input;
    codelet_notw_16(&mut notw, -1);
    let identity = [Complex::new(1.0_f64, 0.0); 15];
    codelet_twiddle_16(&mut tw, &identity, -1);
    for (i, (t, n)) in tw.iter().zip(&notw).enumerate() {
        assert!(
            approx_eq(*t, *n, 1e-10),
            "twiddle_16 identity vs notw_16 index {i}: {t:?} != {n:?}"
        );
    }
}
/// The size-16 twiddle codelet equals a reference DFT of the input in which element
/// `k + 1` has been multiplied by `twiddles[k] = exp(-2πi·(k+1)/16)`.
#[test]
fn test_codelet_twiddle_16_vs_naive() {
    let twiddles: [Complex<f64>; 15] = {
        let mut arr = [Complex::new(0.0_f64, 0.0); 15];
        for (k, item) in arr.iter_mut().enumerate() {
            let angle = -2.0 * core::f64::consts::PI * ((k + 1) as f64) / 16.0;
            *item = Complex::new(angle.cos(), angle.sin());
        }
        arr
    };
    let input: Vec<Complex<f64>> = (0..16_i32)
        .map(f64::from)
        .map(|r| Complex::new(r * 0.4 + 0.1, 1.0 - r * 0.05))
        .collect();
    // Element 0 is never twiddled; elements 1..16 are paired with twiddles[0..15].
    // Iterating the slices avoids manual index arithmetic (clippy::needless_range_loop).
    let mut tweaked = input.clone();
    for (slot, &w) in tweaked[1..].iter_mut().zip(twiddles.iter()) {
        *slot = *slot * w;
    }
    let expected = naive_dft(&tweaked, -1);
    let mut actual = input;
    codelet_twiddle_16(&mut actual, &twiddles, -1);
    for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
        assert!(
            approx_eq(*a, *e, 1e-9),
            "twiddle_16 vs naive index {i}: {a:?} != {e:?}"
        );
    }
}
/// Forward then backward size-16 twiddle transforms with unit twiddles, scaled by
/// 1/16, reproduce the input.
#[test]
fn test_codelet_twiddle_16_roundtrip_identity() {
    let original: Vec<Complex<f64>> = (0..16_i32)
        .map(f64::from)
        .map(|r| Complex::new(r.sin(), r.cos()))
        .collect();
    let identity = [Complex::new(1.0_f64, 0.0); 15];
    let mut data = original.clone();
    for sign in [-1, 1] {
        codelet_twiddle_16(&mut data, &identity, sign);
    }
    data.iter_mut()
        .for_each(|x| *x = Complex::new(x.re / 16.0, x.im / 16.0));
    for (i, (a, e)) in data.iter().zip(&original).enumerate() {
        assert!(
            approx_eq(*a, *e, 1e-10),
            "twiddle_16 roundtrip index {i}: {a:?} != {e:?}"
        );
    }
}