use crate::{next_fast_len, FFTError, FFTResult};
use scirs2_core::ndarray::{Array1, Array2, Zip};
use scirs2_core::numeric::Complex;
use std::f64::consts::PI;
/// Spiral contour `z_k = a · w^{-k}` (`k = 0..m`) on which the chirp-z transform is evaluated.
///
/// Successive points are related by `z_{k+1} = z_k / w`, so `|w| < 1` spirals outward,
/// `|w| > 1` spirals inward and `|w| == 1` stays on a circle of radius `|a|`.
#[derive(Clone, Debug, PartialEq)]
pub struct SpiralContour {
    /// Starting point `z_0` of the contour.
    pub a: Complex<f64>,
    /// Inverse step between consecutive points (`z_{k+1} = z_k · w^{-1}`).
    pub w: Complex<f64>,
    /// Number of contour points, i.e. the number of CZT output values.
    pub m: usize,
}
impl SpiralContour {
    /// Contour of `m` equally spaced points on the unit circle starting at `z = 1`.
    ///
    /// Point `k` is `exp(2πi k / m)` (`a = 1`, `w = exp(-2πi / m)`), i.e. the `m` DFT
    /// frequencies.
    ///
    /// # Errors
    ///
    /// Returns `FFTError::ValueError` if `m == 0`.
    pub fn unit_circle(m: usize) -> FFTResult<Self> {
        if m == 0 {
            return Err(FFTError::ValueError(
                "Number of output points must be positive".to_string(),
            ));
        }
        let w = Complex::from_polar(1.0, -2.0 * PI / m as f64);
        Ok(SpiralContour {
            a: Complex::new(1.0, 0.0),
            w,
            m,
        })
    }

    /// Arc of `m` unit-circle points from normalized frequency `f0` to `f1`, both inclusive.
    ///
    /// Frequencies are in cycles per sample. Point `k` sits at angle
    /// `2π (f0 + k (f1 - f0) / (m - 1))`; with `m == 1` the single point sits at `f0`.
    /// `n` (the input length) does not influence the contour; it is kept for API symmetry.
    ///
    /// # Errors
    ///
    /// Returns `FFTError::ValueError` if `m == 0`, or unless `0 <= f0 < f1 <= 1`
    /// (NaN frequencies are rejected).
    pub fn zoom_range(m: usize, f0: f64, f1: f64, n: usize) -> FFTResult<Self> {
        // Part of the public signature but unused; silence the warning explicitly.
        let _ = n;
        if m == 0 {
            return Err(FFTError::ValueError(
                "Number of output points must be positive".to_string(),
            ));
        }
        // Phrased positively so that NaN fails it: every comparison involving NaN is false,
        // which let NaN through the previous `f0 < 0.0 || f1 > 1.0 || f0 >= f1` test.
        let band_ok = f0 >= 0.0 && f0 < f1 && f1 <= 1.0;
        if !band_ok {
            return Err(FFTError::ValueError(
                "Frequencies must satisfy 0 <= f0 < f1 <= 1".to_string(),
            ));
        }
        let phi_start = 2.0 * PI * f0;
        let phi_end = 2.0 * PI * f1;
        let a = Complex::from_polar(1.0, phi_start);
        let step = if m > 1 {
            (phi_end - phi_start) / (m - 1) as f64
        } else {
            0.0
        };
        let w = Complex::from_polar(1.0, -step);
        Ok(SpiralContour { a, w, m })
    }

    /// Logarithmic spiral `z_k = r0 · rho^k · exp(i (theta0 + k · dtheta))` for `k = 0..m`.
    ///
    /// `rho < 1` spirals inward and `rho > 1` outward.
    ///
    /// # Errors
    ///
    /// Returns `FFTError::ValueError` if `m == 0`, if `r0` or `rho` is not strictly positive
    /// (including NaN), or if any parameter is non-finite.
    pub fn log_spiral(m: usize, r0: f64, rho: f64, theta0: f64, dtheta: f64) -> FFTResult<Self> {
        if m == 0 {
            return Err(FFTError::ValueError(
                "Number of output points must be positive".to_string(),
            ));
        }
        // `> 0.0` is false for NaN, so NaN radii are rejected as well.
        let radius_ok = r0 > 0.0;
        if !radius_ok {
            return Err(FFTError::ValueError(
                "Starting radius must be positive".to_string(),
            ));
        }
        // `w` uses `1 / rho` as its modulus; a zero, negative or NaN ratio cannot describe a
        // spiral and previously produced an infinite or meaningless step.
        let ratio_ok = rho > 0.0;
        if !ratio_ok {
            return Err(FFTError::ValueError(
                "Radial ratio rho must be positive".to_string(),
            ));
        }
        if ![r0, rho, theta0, dtheta].iter().all(|v| v.is_finite()) {
            return Err(FFTError::ValueError(
                "Spiral parameters must be finite".to_string(),
            ));
        }
        let a = Complex::from_polar(r0, theta0);
        let w = Complex::from_polar(1.0 / rho, -dtheta);
        Ok(SpiralContour { a, w, m })
    }

    /// The `m` contour points `z_k = a · w^{-k}`, `k = 0..m`.
    pub fn points(&self) -> Array1<Complex<f64>> {
        (0..self.m)
            .map(|k| self.a * self.w.powf(-(k as f64)))
            .collect()
    }
}
/// Chirp-z transform engine (Bluestein's algorithm) precomputed for one input length and
/// one [`SpiralContour`].
#[derive(Clone, Debug)]
pub struct EnhancedCZT {
    /// Required input signal length.
    n: usize,
    /// Contour the transform is evaluated on.
    contour: SpiralContour,
    /// FFT length used for the fast convolution, from `next_fast_len(n + m - 1)`.
    nfft: usize,
    /// Input weights `a^{-k} · w^{k²/2}` for `k = 0..n`.
    awk2: Array1<Complex<f64>>,
    /// FFT (length `nfft`) of the zero-padded chirp filter `w^{-j²/2}`.
    fwk2: Array1<Complex<f64>>,
    /// Output weights `w^{k²/2}` for `k = 0..m`.
    wk2: Array1<Complex<f64>>,
}
impl EnhancedCZT {
pub fn new(n: usize, contour: SpiralContour) -> FFTResult<Self> {
if n == 0 {
return Err(FFTError::ValueError(
"Input length must be positive".to_string(),
));
}
let m = contour.m;
let a = contour.a;
let w = contour.w;
let max_size = n.max(m);
let nfft = next_fast_len(n + m - 1, false);
let wk2_full: Array1<Complex<f64>> = (0..max_size)
.map(|k| w.powf(k as f64 * k as f64 / 2.0))
.collect();
let awk2: Array1<Complex<f64>> =
(0..n).map(|k| a.powf(-(k as f64)) * wk2_full[k]).collect();
let mut chirp_vec = vec![Complex::new(0.0, 0.0); nfft];
for i in 0..m {
chirp_vec[n - 1 + i] = Complex::new(1.0, 0.0) / wk2_full[i];
}
for i in 1..n {
chirp_vec[n - 1 - i] = Complex::new(1.0, 0.0) / wk2_full[i];
}
let fwk2_vec = crate::fft::fft(&chirp_vec, None)?;
let fwk2 = Array1::from_vec(fwk2_vec);
let wk2: Array1<Complex<f64>> = wk2_full.slice(scirs2_core::ndarray::s![..m]).to_owned();
Ok(EnhancedCZT {
n,
contour,
nfft,
awk2,
fwk2,
wk2,
})
}
pub fn transform(&self, x: &[Complex<f64>]) -> FFTResult<Array1<Complex<f64>>> {
if x.len() != self.n {
return Err(FFTError::ValueError(format!(
"Input length ({}) does not match CZT engine size ({})",
x.len(),
self.n
)));
}
let x_arr = Array1::from_vec(x.to_vec());
let x_weighted: Array1<Complex<f64>> = Zip::from(&x_arr)
.and(&self.awk2)
.map_collect(|&xi, &awki| xi * awki);
let mut padded = vec![Complex::new(0.0, 0.0); self.nfft];
for (i, &val) in x_weighted.iter().enumerate() {
padded[i] = val;
}
let x_fft_vec = crate::fft::fft(&padded, None)?;
let x_fft = Array1::from_vec(x_fft_vec);
let product: Array1<Complex<f64>> = Zip::from(&x_fft)
.and(&self.fwk2)
.map_collect(|&xi, &fi| xi * fi);
let y_full_vec = crate::fft::ifft(&product.to_vec(), None)?;
let y_full = Array1::from_vec(y_full_vec);
let m = self.contour.m;
let y_slice = y_full.slice(scirs2_core::ndarray::s![self.n - 1..self.n - 1 + m]);
let result: Array1<Complex<f64>> = Zip::from(&y_slice)
.and(&self.wk2)
.map_collect(|&yi, &wki| yi * wki);
Ok(result)
}
pub fn transform_real(&self, x: &[f64]) -> FFTResult<Array1<Complex<f64>>> {
let x_complex: Vec<Complex<f64>> = x.iter().map(|&v| Complex::new(v, 0.0)).collect();
self.transform(&x_complex)
}
pub fn transform_batch(
&self,
signals: &Array2<Complex<f64>>,
) -> FFTResult<Array2<Complex<f64>>> {
let (num_signals, signal_len) = signals.dim();
if signal_len != self.n {
return Err(FFTError::ValueError(format!(
"Signal length ({signal_len}) does not match CZT engine size ({})",
self.n
)));
}
let m = self.contour.m;
let mut results = Array2::zeros((num_signals, m));
for i in 0..num_signals {
let row = signals.row(i);
let row_vec: Vec<Complex<f64>> = row.iter().copied().collect();
let transformed = self.transform(&row_vec)?;
for (j, &val) in transformed.iter().enumerate() {
results[[i, j]] = val;
}
}
Ok(results)
}
pub fn points(&self) -> Array1<Complex<f64>> {
self.contour.points()
}
pub fn contour(&self) -> &SpiralContour {
&self.contour
}
}
/// Inverse chirp-z transform by least squares.
///
/// Reconstructs an `n`-point signal `x` from CZT samples `X_k = Σ_j x_j z_k^{-j}`. The samples
/// are taken at the first `czt_values.len()` points of `contour`. The Vandermonde system
/// `V x = X`, with `V[k, j] = z_k^{-j}`, is solved through the normal equations
/// `Vᴴ V x = Vᴴ X`. Forming `Vᴴ V` squares the condition number of `V`, so contours that are
/// far from the unit circle, or points clustered closely together, lose precision quickly.
///
/// # Errors
///
/// - `FFTError::ValueError` if fewer than `n` CZT values are supplied.
/// - `FFTError::ValueError` if more values are supplied than the contour has points.
/// - `FFTError::ComputationError` if the normal-equation matrix is (near-)singular.
pub fn iczt(
    czt_values: &[Complex<f64>],
    n: usize,
    contour: &SpiralContour,
) -> FFTResult<Array1<Complex<f64>>> {
    let m = czt_values.len();
    if m < n {
        return Err(FFTError::ValueError(format!(
            "Need at least {n} CZT values to reconstruct {n}-point signal, got {m}"
        )));
    }
    // Value k is paired with contour point k. Supplying more values than points used to index
    // past the end of `contour.points()` and panic.
    if m > contour.m {
        return Err(FFTError::ValueError(format!(
            "Got {m} CZT values but the contour has only {} points",
            contour.m
        )));
    }
    let z_points = contour.points();
    // V[k, j] = z_k^{-j}, built row by row through repeated division.
    let mut v_mat = Array2::zeros((m, n));
    for (k, &z_k) in z_points.iter().take(m).enumerate() {
        let mut z_power = Complex::new(1.0, 0.0);
        for j in 0..n {
            v_mat[[k, j]] = z_power;
            z_power = z_power / z_k;
        }
    }
    // Right-hand side Vᴴ X.
    let mut vhb = Array1::zeros(n);
    for j in 0..n {
        let mut sum = Complex::new(0.0, 0.0);
        for k in 0..m {
            sum += v_mat[[k, j]].conj() * czt_values[k];
        }
        vhb[j] = sum;
    }
    // Gram matrix Vᴴ V.
    let mut vhv = Array2::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            let mut sum = Complex::new(0.0, 0.0);
            for k in 0..m {
                sum += v_mat[[k, i]].conj() * v_mat[[k, j]];
            }
            vhv[[i, j]] = sum;
        }
    }
    solve_complex_system(&vhv, &vhb)
}
/// Solves the square complex system `a · x = b` by Gaussian elimination with partial pivoting.
///
/// Only the leading `b.len() × b.len()` block of `a` is read. Elimination runs on an augmented
/// copy `[a | b]`. In each column the pivot is the first row, at or below the diagonal, whose
/// entry has the largest modulus.
///
/// # Errors
///
/// Returns `FFTError::ComputationError` when the chosen pivot's modulus is below `1e-14`.
/// This threshold is absolute, not relative to the matrix scale.
fn solve_complex_system(
    a: &Array2<Complex<f64>>,
    b: &Array1<Complex<f64>>,
) -> FFTResult<Array1<Complex<f64>>> {
    let dim = b.len();
    // Column `dim` of the augmented matrix carries the right-hand side.
    let mut aug: Array2<Complex<f64>> =
        Array2::from_shape_fn((dim, dim + 1), |(r, c)| if c < dim { a[[r, c]] } else { b[r] });

    for col in 0..dim {
        // Partial pivoting; a later row wins only with a strictly larger modulus.
        let (pivot_row, pivot_mag) = ((col + 1)..dim).fold(
            (col, aug[[col, col]].norm()),
            |(best_row, best_mag), r| {
                let mag = aug[[r, col]].norm();
                if mag > best_mag {
                    (r, mag)
                } else {
                    (best_row, best_mag)
                }
            },
        );
        if pivot_mag < 1e-14 {
            return Err(FFTError::ComputationError(
                "Singular or near-singular system in ICZT".to_string(),
            ));
        }
        if pivot_row != col {
            for c in 0..=dim {
                aug.swap([col, c], [pivot_row, c]);
            }
        }
        let pivot = aug[[col, col]];
        for r in (col + 1)..dim {
            let factor = aug[[r, col]] / pivot;
            for c in col..=dim {
                let upper = aug[[col, c]];
                aug[[r, c]] -= factor * upper;
            }
        }
    }

    // Back substitution on the upper-triangular system.
    let mut solution: Array1<Complex<f64>> = Array1::zeros(dim);
    for row in (0..dim).rev() {
        let rhs = ((row + 1)..dim).fold(aug[[row, dim]], |acc, c| acc - aug[[row, c]] * solution[c]);
        solution[row] = rhs / aug[[row, row]];
    }
    Ok(solution)
}
/// Linear convolution of two real sequences computed with zero-padded FFTs.
///
/// Despite the name, this uses plain FFTs rather than a chirp-z transform. Both inputs are
/// zero-padded to `next_fast_len(a.len() + b.len() - 1)`. Their spectra are multiplied and
/// inverse-transformed, and the real part of the first `a.len() + b.len() - 1` samples is
/// returned. For real inputs the imaginary part is only round-off.
///
/// # Errors
///
/// Returns `FFTError::ValueError` if either input is empty. FFT errors are propagated.
pub fn czt_convolve(a: &[f64], b: &[f64]) -> FFTResult<Vec<f64>> {
    if a.is_empty() || b.is_empty() {
        return Err(FFTError::ValueError(
            "Input sequences cannot be empty".to_string(),
        ));
    }
    let out_len = a.len() + b.len() - 1;
    let fft_len = next_fast_len(out_len, false);
    // Lift a real sequence to complex and zero-pad it to the FFT length.
    let to_padded = |seq: &[f64]| -> Vec<Complex<f64>> {
        let mut buf: Vec<Complex<f64>> = Vec::with_capacity(fft_len);
        buf.extend(seq.iter().map(|&re| Complex::new(re, 0.0)));
        buf.resize(fft_len, Complex::new(0.0, 0.0));
        buf
    };
    let spec_a = crate::fft::fft(&to_padded(a), None)?;
    let spec_b = crate::fft::fft(&to_padded(b), None)?;
    let spectrum: Vec<Complex<f64>> = spec_a
        .iter()
        .zip(spec_b.iter())
        .map(|(&p, &q)| p * q)
        .collect();
    let time_domain = crate::fft::ifft(&spectrum, None)?;
    Ok(time_domain.into_iter().take(out_len).map(|c| c.re).collect())
}
/// Zoom FFT whose number of output bins adapts to the signal length.
///
/// Evaluates the spectrum of the real signal `x` on `m` equally spaced normalized frequencies
/// (cycles per sample) spanning `[f0, f1]`, both ends inclusive. `m` is
/// `ceil((f1 - f0) / (1 / x.len()))`, the number of Rayleigh-resolution cells in the band,
/// clamped to `[min_points, max_points]`.
///
/// Returns the bin frequencies together with the complex spectrum values at those bins.
///
/// # Errors
///
/// Returns `FFTError::ValueError` in any of these cases:
/// - `x` is empty.
/// - the band violates `0 <= f0 < f1 <= 1`, including NaN bounds.
/// - the point limits violate `0 < min_points <= max_points`.
///
/// Contour/engine construction and FFT errors are propagated.
pub fn adaptive_zoom_fft(
    x: &[f64],
    f0: f64,
    f1: f64,
    min_points: usize,
    max_points: usize,
) -> FFTResult<(Vec<f64>, Array1<Complex<f64>>)> {
    if x.is_empty() {
        return Err(FFTError::ValueError("Input signal is empty".to_string()));
    }
    // Phrased positively so NaN fails it. The previous `f0 < 0.0 || f1 > 1.0 || f0 >= f1`
    // accepted NaN, which then became 0 points via `as usize` and a NaN spectrum.
    let band_ok = f0 >= 0.0 && f0 < f1 && f1 <= 1.0;
    if !band_ok {
        return Err(FFTError::ValueError(
            "Frequency range must satisfy 0 <= f0 < f1 <= 1".to_string(),
        ));
    }
    if min_points == 0 || max_points < min_points {
        return Err(FFTError::ValueError(
            "Point count must satisfy 0 < min_points <= max_points".to_string(),
        ));
    }
    let n = x.len();
    let freq_range = f1 - f0;
    // Frequency resolution of an n-sample record, in cycles per sample.
    let rayleigh_resolution = 1.0 / n as f64;
    let ideal_points = (freq_range / rayleigh_resolution).ceil() as usize;
    let m = ideal_points.clamp(min_points, max_points);
    let contour = SpiralContour::zoom_range(m, f0, f1, n)?;
    let engine = EnhancedCZT::new(n, contour)?;
    let spectrum = engine.transform_real(x)?;
    // Same grid as `zoom_range`: m points from f0 to f1 inclusive (just f0 when m == 1).
    let frequencies: Vec<f64> = (0..m)
        .map(|k| {
            if m > 1 {
                f0 + k as f64 * (f1 - f0) / (m - 1) as f64
            } else {
                f0
            }
        })
        .collect();
    Ok((frequencies, spectrum))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_unit_circle_contour() {
        let contour = SpiralContour::unit_circle(8).expect("Unit circle contour should succeed");
        let pts = contour.points();
        assert_eq!(pts.len(), 8);
        for p in pts.iter() {
            assert_abs_diff_eq!(p.norm(), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_zoom_range_contour() {
        let contour =
            SpiralContour::zoom_range(16, 0.1, 0.3, 64).expect("Zoom range contour should succeed");
        let pts = contour.points();
        assert_eq!(pts.len(), 16);
        for p in pts.iter() {
            assert_abs_diff_eq!(p.norm(), 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_zoom_range_endpoints() {
        // The arc must start exactly at f0 and end exactly at f1 (both inclusive).
        let contour =
            SpiralContour::zoom_range(5, 0.1, 0.3, 64).expect("Zoom range contour should succeed");
        let pts = contour.points();
        let first = Complex::from_polar(1.0, 2.0 * PI * 0.1);
        let last = Complex::from_polar(1.0, 2.0 * PI * 0.3);
        assert_abs_diff_eq!(pts[0].re, first.re, epsilon = 1e-12);
        assert_abs_diff_eq!(pts[0].im, first.im, epsilon = 1e-12);
        assert_abs_diff_eq!(pts[4].re, last.re, epsilon = 1e-10);
        assert_abs_diff_eq!(pts[4].im, last.im, epsilon = 1e-10);
    }

    #[test]
    fn test_log_spiral_contour() {
        let contour =
            SpiralContour::log_spiral(10, 1.0, 0.95, 0.0, 0.1).expect("Log spiral should succeed");
        let pts = contour.points();
        assert_eq!(pts.len(), 10);
        assert_abs_diff_eq!(pts[0].re, 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(pts[0].im, 0.0, epsilon = 1e-10);
        for k in 1..10 {
            let expected_r = 0.95_f64.powi(k as i32);
            assert_abs_diff_eq!(pts[k].norm(), expected_r, epsilon = 1e-8);
        }
    }

    #[test]
    fn test_enhanced_czt_matches_fft() {
        let n = 16;
        let contour = SpiralContour::unit_circle(n).expect("Contour should succeed");
        let engine = EnhancedCZT::new(n, contour).expect("Engine creation should succeed");
        let x: Vec<Complex<f64>> = (0..n).map(|i| Complex::new(i as f64, 0.0)).collect();
        let czt_result = engine.transform(&x).expect("Transform should succeed");
        let fft_result_vec = crate::fft::fft(&x, None).expect("FFT should succeed");
        let fft_result = Array1::from_vec(fft_result_vec);
        for i in 0..n {
            assert_abs_diff_eq!(czt_result[i].re, fft_result[i].re, epsilon = 1e-8);
            assert_abs_diff_eq!(czt_result[i].im, fft_result[i].im, epsilon = 1e-8);
        }
    }

    #[test]
    fn test_enhanced_czt_real_input() {
        let n = 8;
        let contour = SpiralContour::unit_circle(n).expect("Contour should succeed");
        let engine = EnhancedCZT::new(n, contour).expect("Engine should succeed");
        let x: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let result = engine
            .transform_real(&x)
            .expect("Real transform should succeed");
        let expected_dc: f64 = x.iter().sum();
        assert_abs_diff_eq!(result[0].re, expected_dc, epsilon = 1e-8);
    }

    #[test]
    fn test_batch_czt() {
        let n = 8;
        let contour = SpiralContour::unit_circle(n).expect("Contour should succeed");
        let engine = EnhancedCZT::new(n, contour).expect("Engine should succeed");
        let mut signals = Array2::zeros((3, n));
        for i in 0..3 {
            for j in 0..n {
                signals[[i, j]] = Complex::new((i * n + j) as f64, 0.0);
            }
        }
        let results = engine
            .transform_batch(&signals)
            .expect("Batch transform should succeed");
        assert_eq!(results.dim(), (3, n));
        for i in 0..3 {
            let row_vec: Vec<Complex<f64>> = signals.row(i).iter().copied().collect();
            let individual = engine
                .transform(&row_vec)
                .expect("Individual transform should succeed");
            for j in 0..n {
                assert_abs_diff_eq!(results[[i, j]].re, individual[j].re, epsilon = 1e-8);
                assert_abs_diff_eq!(results[[i, j]].im, individual[j].im, epsilon = 1e-8);
            }
        }
    }

    #[test]
    fn test_iczt_roundtrip() {
        let n = 8;
        let contour = SpiralContour::unit_circle(n).expect("Contour should succeed");
        let engine = EnhancedCZT::new(n, contour.clone()).expect("Engine should succeed");
        let x: Vec<Complex<f64>> = (0..n).map(|i| Complex::new(i as f64 + 1.0, 0.0)).collect();
        let czt_values = engine.transform(&x).expect("Forward CZT should succeed");
        let czt_vec: Vec<Complex<f64>> = czt_values.iter().copied().collect();
        let recovered = iczt(&czt_vec, n, &contour).expect("ICZT should succeed");
        for i in 0..n {
            assert_abs_diff_eq!(recovered[i].re, x[i].re, epsilon = 1e-6);
            assert_abs_diff_eq!(recovered[i].im, x[i].im, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_czt_convolve() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0];
        let result = czt_convolve(&a, &b).expect("Convolution should succeed");
        assert_eq!(result.len(), 4);
        let expected = [4.0, 13.0, 22.0, 15.0];
        for (&r, &e) in result.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(r, e, epsilon = 1e-8);
        }
    }

    #[test]
    fn test_czt_convolve_identity() {
        let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let delta = vec![1.0];
        let result = czt_convolve(&signal, &delta).expect("Identity convolution should succeed");
        assert_eq!(result.len(), signal.len());
        for (&r, &s) in result.iter().zip(signal.iter()) {
            assert_abs_diff_eq!(r, s, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_adaptive_zoom_fft() {
        let n = 256;
        let freq = 0.15;
        let x: Vec<f64> = (0..n).map(|i| (2.0 * PI * freq * i as f64).sin()).collect();
        let (frequencies, spectrum) =
            adaptive_zoom_fft(&x, 0.1, 0.2, 16, 128).expect("Adaptive zoom FFT should succeed");
        assert_eq!(frequencies.len(), spectrum.len());
        assert!(frequencies.len() >= 16);
        assert!(frequencies.len() <= 128);
        let magnitudes: Vec<f64> = spectrum.iter().map(|c| c.norm()).collect();
        let peak_idx = magnitudes
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0);
        let peak_freq = frequencies[peak_idx];
        assert!(
            (peak_freq - freq).abs() < 0.02,
            "Peak at {peak_freq:.4} should be near {freq:.4}"
        );
    }

    #[test]
    fn test_adaptive_zoom_fft_frequency_grid() {
        // The returned frequency grid spans [f0, f1] with both endpoints included.
        let x: Vec<f64> = (0..32).map(|i| i as f64).collect();
        let (frequencies, spectrum) =
            adaptive_zoom_fft(&x, 0.1, 0.2, 4, 64).expect("Adaptive zoom FFT should succeed");
        assert_eq!(frequencies.len(), spectrum.len());
        assert_abs_diff_eq!(frequencies[0], 0.1, epsilon = 1e-12);
        let last = *frequencies.last().expect("Grid has at least min_points entries");
        assert_abs_diff_eq!(last, 0.2, epsilon = 1e-12);
    }

    #[test]
    fn test_parseval_theorem_czt() {
        let n = 16;
        let contour = SpiralContour::unit_circle(n).expect("Contour should succeed");
        let engine = EnhancedCZT::new(n, contour).expect("Engine should succeed");
        let x: Vec<Complex<f64>> = (0..n)
            .map(|i| Complex::new((2.0 * PI * 3.0 * i as f64 / n as f64).sin(), 0.0))
            .collect();
        let czt_result = engine.transform(&x).expect("Transform should succeed");
        let input_energy: f64 = x.iter().map(|c| c.norm_sqr()).sum();
        let output_energy: f64 = czt_result.iter().map(|c| c.norm_sqr()).sum::<f64>() / n as f64;
        assert_abs_diff_eq!(input_energy, output_energy, epsilon = 1e-8);
    }

    #[test]
    fn test_czt_prime_length() {
        let n = 13;
        let contour = SpiralContour::unit_circle(n).expect("Contour should succeed");
        let engine = EnhancedCZT::new(n, contour).expect("Engine should succeed");
        let x: Vec<Complex<f64>> = (0..n).map(|i| Complex::new(i as f64, 0.0)).collect();
        let result = engine
            .transform(&x)
            .expect("Prime-length CZT should succeed");
        assert_eq!(result.len(), n);
        let expected_dc: f64 = (0..n).map(|i| i as f64).sum();
        assert_abs_diff_eq!(result[0].re, expected_dc, epsilon = 1e-8);
    }

    #[test]
    fn test_zoom_fft_resolves_close_frequencies() {
        // NOTE: the tones are 0.01 cycles/sample apart. That is below the 1/64 ≈ 0.0156
        // Rayleigh limit for n = 64, and zooming only interpolates the spectrum; it does not
        // add resolution. This test therefore only checks that the zoomed band captures
        // spectral energy, not that the two peaks are separated.
        let n = 64;
        let f1_norm = 0.15;
        let f2_norm = 0.16;
        let x: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * f1_norm * i as f64).sin() + (2.0 * PI * f2_norm * i as f64).sin())
            .collect();
        let contour =
            SpiralContour::zoom_range(128, 0.12, 0.20, n).expect("Zoom contour should succeed");
        let engine = EnhancedCZT::new(n, contour).expect("Engine should succeed");
        let spectrum = engine.transform_real(&x).expect("Zoom CZT should succeed");
        let magnitudes: Vec<f64> = spectrum.iter().map(|c| c.norm()).collect();
        let max_mag = magnitudes.iter().copied().fold(0.0_f64, f64::max);
        assert!(max_mag > 1.0, "Zoom should find spectral energy");
    }

    #[test]
    fn test_error_handling() {
        assert!(SpiralContour::unit_circle(0).is_err());
        assert!(SpiralContour::zoom_range(0, 0.0, 0.5, 64).is_err());
        assert!(SpiralContour::zoom_range(16, 0.5, 0.3, 64).is_err());
        assert!(SpiralContour::log_spiral(10, -1.0, 0.95, 0.0, 0.1).is_err());
        assert!(czt_convolve(&[], &[1.0]).is_err());
        assert!(adaptive_zoom_fft(&[], 0.0, 0.5, 8, 64).is_err());
    }
}