use super::{CztError, CztPlan};
use crate::api::{Direction, Flags, Plan};
use crate::kernel::Complex;
/// Direct O(N·M) evaluation of the chirp z-transform, used as a reference.
///
/// Computes `X[k] = Σₙ x[n] · A^(-n) · W^(n·k)` for `k = 0..m`. Each power of
/// `a` and `w` is obtained independently through `complex_pow_f64` rather than
/// by repeated multiplication.
fn naive_czt_f64(
    x: &[Complex<f64>],
    m: usize,
    a: Complex<f64>,
    w: Complex<f64>,
) -> Vec<Complex<f64>> {
    (0..m)
        .map(|k| {
            let k_f = k as f64;
            x.iter()
                .enumerate()
                .fold(Complex::<f64>::zero(), |sum, (idx, &sample)| {
                    let idx_f = idx as f64;
                    let a_term = complex_pow_f64(a, -idx_f);
                    let w_term = complex_pow_f64(w, idx_f * k_f);
                    sum + sample * a_term * w_term
                })
        })
        .collect()
}
/// Raises `z` to the real power `p` through its polar form,
/// `|z|^p · exp(i · p · arg z)`. Here `arg z` is the principal argument
/// returned by `atan2`, which lies in `[-π, π]`.
///
/// `naive_czt_f64` only passes integer exponents, so the branch choice does
/// not affect it. For a fractional `p` the result is the principal value.
fn complex_pow_f64(z: Complex<f64>, p: f64) -> Complex<f64> {
    let magnitude = z.norm().powf(p);
    let phase = z.im.atan2(z.re);
    Complex::from_polar(magnitude, p * phase)
}
/// Returns the largest element-wise modulus `|a[i] - b[i]|` between two
/// equal-length complex slices. Returns `0.0` when both slices are empty.
///
/// A NaN difference anywhere makes the result NaN. Callers check the result
/// with `err < tol`, so NaN output from the code under test fails the check
/// instead of being silently discarded, which is what `f64::max` would do.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
fn max_err_f64(a: &[Complex<f64>], b: &[Complex<f64>]) -> f64 {
    assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b)
        .map(|(&ai, &bi)| {
            let d = ai - bi;
            // hypot avoids the intermediate overflow/underflow of sqrt(re² + im²).
            d.re.hypot(d.im)
        })
        .fold(0.0_f64, |acc, e| {
            // f64::max(x, NaN) == x, which would hide NaN errors; keep NaN sticky.
            if acc.is_nan() || e.is_nan() {
                f64::NAN
            } else {
                acc.max(e)
            }
        })
}
/// Returns the largest element-wise modulus `|a[i] - b[i]|` between two
/// equal-length single-precision complex slices, widened to `f64`. Returns
/// `0.0` when both slices are empty.
///
/// The difference components are widened to `f64` before taking the modulus,
/// so the reported error does not pick up an extra f32 rounding step.
/// A NaN difference anywhere makes the result NaN, so `err < tol` checks fail
/// instead of silently passing, which is what `f64::max` would allow.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
fn max_err_f32(a: &[Complex<f32>], b: &[Complex<f32>]) -> f64 {
    assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b)
        .map(|(&ai, &bi)| {
            let d = ai - bi;
            f64::from(d.re).hypot(f64::from(d.im))
        })
        .fold(0.0_f64, |acc, e| {
            // f64::max(x, NaN) == x, which would hide NaN errors; keep NaN sticky.
            if acc.is_nan() || e.is_nan() {
                f64::NAN
            } else {
                acc.max(e)
            }
        })
}
/// With `A = 1` and `W = exp(-2πi/n)`, the CZT samples the `n` DFT
/// frequencies on the unit circle. It must therefore agree with the forward
/// DFT plan to within 1e-10 for each tested size.
#[test]
fn identity_czt_matches_dft_f64() {
    for n in [16_usize, 64, 256].iter().copied() {
        let a = Complex::<f64>::one();
        let w = Complex::from_polar(1.0_f64, -2.0 * core::f64::consts::PI / n as f64);

        let plan = CztPlan::<f64>::new(n, n, a, w).expect("CZT plan failed");
        let dft_plan =
            Plan::<f64>::dft_1d(n, Direction::Forward, Flags::ESTIMATE).expect("DFT plan failed");

        let input: Vec<Complex<f64>> = (0..n)
            .map(|i| {
                let t = i as f64;
                Complex::new((t * 0.37).sin(), (t * 0.71).cos())
            })
            .collect();

        let mut chirp_spectrum = vec![Complex::zero(); n];
        let mut fft_spectrum = vec![Complex::zero(); n];
        plan.execute(&input, &mut chirp_spectrum)
            .expect("CZT execute failed");
        dft_plan.execute(&input, &mut fft_spectrum);

        let err = max_err_f64(&chirp_spectrum, &fft_spectrum);
        assert!(
            err < 1e-10,
            "identity CZT vs DFT error {err:.2e} > 1e-10 for n={n}"
        );
    }
}
/// Zoom FFT over the 100–110 Hz band with 100 output bins: the strongest bin
/// for a 105.3 Hz cosine must lie within ±1 of the bin predicted by dividing
/// the band into `m` equal steps.
#[test]
fn zoom_fft_peak_detection() {
    let n = 1024_usize;
    let m = 100_usize;
    let fs = 1000.0_f64;
    let f_signal = 105.3_f64;
    let (f_start, f_stop) = (100.0_f64, 110.0_f64);

    let omega = 2.0 * core::f64::consts::PI * f_signal;
    let input: Vec<Complex<f64>> = (0..n)
        .map(|i| Complex::new((omega * (i as f64 / fs)).cos(), 0.0))
        .collect();

    let plan = CztPlan::<f64>::zoom_fft(n, m, f_start, f_stop, fs).expect("zoom_fft plan failed");
    let mut output = vec![Complex::zero(); m];
    plan.execute(&input, &mut output)
        .expect("zoom_fft execute failed");

    // Squared magnitude is enough for locating the maximum.
    let power = |c: &Complex<f64>| c.re * c.re + c.im * c.im;
    let peak_bin = output
        .iter()
        .enumerate()
        .max_by(|&(_, lhs), &(_, rhs)| {
            power(lhs)
                .partial_cmp(&power(rhs))
                .unwrap_or(core::cmp::Ordering::Equal)
        })
        .map_or(0, |(idx, _)| idx);

    let span = f_stop - f_start;
    let expected_bin = ((f_signal - f_start) / span * m as f64).round() as usize;
    let distance = peak_bin.abs_diff(expected_bin);
    assert!(
        distance <= 1,
        "zoom_fft peak at bin {peak_bin}, expected ~{expected_bin} (±1)"
    );
}
/// Spiral contour off the unit circle (|A| = 1.05, |W| = 0.99): the planned
/// CZT must match the direct summation of `naive_czt_f64` to within 1e-11.
#[test]
fn off_unit_circle_matches_naive() {
    let (n, m) = (8_usize, 8_usize);
    let a = Complex::from_polar(1.05_f64, 0.0_f64);
    let w = Complex::from_polar(0.99_f64, -2.0 * core::f64::consts::PI / 8.0);
    let plan = CztPlan::<f64>::new(n, m, a, w).expect("off-unit plan failed");

    let input: Vec<Complex<f64>> = (0..n)
        .map(|i| {
            let idx = i as f64;
            Complex::new((idx + 1.0) * 0.5, -idx * 0.3)
        })
        .collect();

    let mut fast = vec![Complex::zero(); m];
    plan.execute(&input, &mut fast)
        .expect("off-unit execute failed");

    let reference = naive_czt_f64(&input, m, a, w);
    let err = max_err_f64(&fast, &reference);
    assert!(err < 1e-11, "off-unit-circle error {err:.2e} > 1e-11");
}
/// Input length and output length differ (N = 256 samples, M = 64 outputs) on
/// the DFT contour. The planned CZT must match direct summation to within 1e-10.
#[test]
fn different_n_m_matches_naive() {
    let (n, m) = (256_usize, 64_usize);
    let a = Complex::<f64>::one();
    let w = Complex::from_polar(1.0_f64, -2.0 * core::f64::consts::PI / n as f64);
    let plan = CztPlan::<f64>::new(n, m, a, w).expect("N/M plan failed");

    let input: Vec<Complex<f64>> = (0..n)
        .map(|i| {
            let t = i as f64;
            Complex::new((t * 0.13).sin(), (t * 0.29).cos())
        })
        .collect();

    let mut fast = vec![Complex::zero(); m];
    plan.execute(&input, &mut fast)
        .expect("N/M execute failed");

    let reference = naive_czt_f64(&input, m, a, w);
    let err = max_err_f64(&fast, &reference);
    assert!(err < 1e-10, "N=256/M=64 error {err:.2e} > 1e-10");
}
/// Single-precision counterpart of the identity check. A CZT on the 64 DFT
/// frequencies must match the f32 forward DFT within a loose 5e-3 tolerance.
#[test]
fn identity_czt_matches_dft_f32() {
    let n = 64_usize;
    let a = Complex::<f32>::one();
    let w = Complex::from_polar(1.0_f32, -2.0 * core::f32::consts::PI / n as f32);

    let plan = CztPlan::<f32>::new(n, n, a, w).expect("f32 CZT plan failed");
    let dft_plan =
        Plan::<f32>::dft_1d(n, Direction::Forward, Flags::ESTIMATE).expect("f32 DFT plan failed");

    let input: Vec<Complex<f32>> = (0..n)
        .map(|i| {
            let t = i as f32;
            Complex::new((t * 0.37).sin(), (t * 0.71).cos())
        })
        .collect();

    let mut chirp_spectrum = vec![Complex::zero(); n];
    let mut fft_spectrum = vec![Complex::zero(); n];
    plan.execute(&input, &mut chirp_spectrum)
        .expect("f32 CZT execute failed");
    dft_plan.execute(&input, &mut fft_spectrum);

    let err = max_err_f32(&chirp_spectrum, &fft_spectrum);
    assert!(err < 5e-3, "f32 identity CZT vs DFT error {err:.2e} > 5e-3");
}
/// Each invalid request must be rejected with the matching `CztError` variant:
/// - a zero size;
/// - a buffer whose length differs from the plan's;
/// - an inverted zoom band.
#[test]
fn czt_error_paths() {
    let a = Complex::<f64>::one();
    let w = Complex::from_polar(1.0_f64, -2.0 * core::f64::consts::PI / 8.0);

    // A zero input length or zero output length is rejected at plan time.
    assert!(matches!(
        CztPlan::<f64>::new(0, 8, a, w),
        Err(CztError::InvalidSize(0))
    ));
    assert!(matches!(
        CztPlan::<f64>::new(8, 0, a, w),
        Err(CztError::InvalidSize(0))
    ));

    // An 8 -> 8 plan refuses a 4-element input and a 4-element output buffer.
    let plan = CztPlan::<f64>::new(8, 8, a, w).expect("plan");
    let short_in = vec![Complex::<f64>::zero(); 4];
    let full_in = vec![Complex::<f64>::zero(); 8];
    let mut full_out = vec![Complex::<f64>::zero(); 8];
    let mut short_out = vec![Complex::<f64>::zero(); 4];
    assert!(matches!(
        plan.execute(&short_in, &mut full_out),
        Err(CztError::MismatchedLength {
            expected: 8,
            actual: 4
        })
    ));
    assert!(matches!(
        plan.execute(&full_in, &mut short_out),
        Err(CztError::MismatchedLength {
            expected: 8,
            actual: 4
        })
    ));

    // f_start (200 Hz) above f_stop (100 Hz) is not a valid zoom band.
    assert!(matches!(
        CztPlan::<f64>::zoom_fft(64, 32, 200.0, 100.0, 1000.0),
        Err(CztError::InvalidParameter)
    ));
}