1use crate::{next_fast_len, FFTError, FFTResult};
8use ndarray::{s, Array, Array1, ArrayBase, ArrayD, Axis, Data, Dimension, RemoveAxis, Zip};
9use num_complex::Complex;
10use std::f64::consts::PI;
11
12#[allow(dead_code)]
25pub fn czt_points(
26 m: usize,
27 a: Option<Complex<f64>>,
28 w: Option<Complex<f64>>,
29) -> Array1<Complex<f64>> {
30 let a = a.unwrap_or(Complex::new(1.0, 0.0));
31 let k = Array1::linspace(0.0, (m - 1) as f64, m);
32
33 if let Some(w) = w {
34 k.mapv(|ki| a * w.powf(-ki))
36 } else {
37 k.mapv(|ki| a * (Complex::new(0.0, 2.0 * PI * ki / m as f64)).exp())
39 }
40}
41
42#[derive(Clone)]
46pub struct CZT {
47 n: usize,
48 m: usize,
49 w: Option<Complex<f64>>,
50 a: Complex<f64>,
51 nfft: usize,
52 awk2: Array1<Complex<f64>>,
53 fwk2: Array1<Complex<f64>>,
54 wk2: Array1<Complex<f64>>,
55}
56
57impl CZT {
58 pub fn new(
66 n: usize,
67 m: Option<usize>,
68 w: Option<Complex<f64>>,
69 a: Option<Complex<f64>>,
70 ) -> FFTResult<Self> {
71 if n < 1 {
72 return Err(FFTError::ValueError("n must be positive".to_string()));
73 }
74
75 let m = m.unwrap_or(n);
76 if m < 1 {
77 return Err(FFTError::ValueError("m must be positive".to_string()));
78 }
79
80 let a = a.unwrap_or(Complex::new(1.0, 0.0));
81 let max_size = n.max(m);
82 let k = Array1::linspace(0.0, (max_size - 1) as f64, max_size);
83
84 let (w, wk2) = if let Some(w) = w {
85 let wk2 = k.mapv(|ki| w.powf(ki * ki / 2.0));
87 (Some(w), wk2)
88 } else {
89 let w = (-2.0 * PI * Complex::<f64>::i() / m as f64).exp();
91 let wk2 = k.mapv(|ki| {
92 let ki_i64 = ki as i64;
93 let phase = -(PI * ((ki_i64 * ki_i64) % (2 * m as i64)) as f64) / m as f64;
94 Complex::from_polar(1.0, phase)
95 });
96 (Some(w), wk2)
97 };
98
99 let nfft = next_fast_len(n + m - 1, false);
101
102 let awk2: Array1<Complex<f64>> = (0..n).map(|k| a.powf(-(k as f64)) * wk2[k]).collect();
104
105 let mut chirp_vec = vec![Complex::new(0.0, 0.0); nfft];
107
108 for i in 1..n {
110 chirp_vec[n - 1 - i] = Complex::new(1.0, 0.0) / wk2[i];
111 }
112 for i in 0..m {
113 chirp_vec[n - 1 + i] = Complex::new(1.0, 0.0) / wk2[i];
114 }
115
116 let chirp_array = Array1::from_vec(chirp_vec);
117 let fwk2_vec = crate::fft::fft(&chirp_array.to_vec(), None)?;
118 let fwk2 = Array1::from_vec(fwk2_vec);
119
120 Ok(CZT {
121 n,
122 m,
123 w,
124 a,
125 nfft,
126 awk2,
127 fwk2,
128 wk2: wk2.slice(s![..m]).to_owned(),
129 })
130 }
131
132 pub fn points(&self) -> Array1<Complex<f64>> {
134 czt_points(self.m, Some(self.a), self.w)
135 }
136
137 pub fn transform<S, D>(
143 &self,
144 x: &ArrayBase<S, D>,
145 axis: Option<i32>,
146 ) -> FFTResult<ArrayD<Complex<f64>>>
147 where
148 S: Data<Elem = Complex<f64>>,
149 D: Dimension + RemoveAxis,
150 {
151 let ndim = x.ndim();
152 let axis = if let Some(ax) = axis {
153 if ax < 0 {
154 let ax_pos = (ndim as i32 + ax) as usize;
155 if ax_pos >= ndim {
156 return Err(FFTError::ValueError("Invalid axis".to_string()));
157 }
158 ax_pos
159 } else {
160 ax as usize
161 }
162 } else {
163 ndim - 1
164 };
165
166 let axis_len = x.shape()[axis];
167 if axis_len != self.n {
168 return Err(FFTError::ValueError(format!(
169 "Input size ({}) doesn't match CZT size ({})",
170 axis_len, self.n
171 )));
172 }
173
174 let mut outputshape = x.shape().to_vec();
176 outputshape[axis] = self.m;
177 let mut result = Array::<Complex<f64>, _>::zeros(outputshape).into_dyn();
178
179 if x.ndim() == 1 {
182 let x_1d: Array1<Complex<f64>> = x
183 .to_owned()
184 .into_shape_with_order(x.len())
185 .map_err(|e| {
186 FFTError::ComputationError(format!("Failed to reshape input array to 1D: {e}"))
187 })?
188 .into_dimensionality()
189 .map_err(|e| {
190 FFTError::ComputationError(format!(
191 "Failed to convert array dimensionality: {e}"
192 ))
193 })?;
194 let y = self.transform_1d(&x_1d)?;
195 return Ok(y.into_dyn());
196 }
197
198 for (i, x_slice) in x.axis_iter(Axis(axis)).enumerate() {
200 let x_1d: Array1<Complex<f64>> = x_slice
202 .to_owned()
203 .into_shape_with_order(x_slice.len())
204 .map_err(|e| {
205 FFTError::ComputationError(format!("Failed to reshape slice to 1D array: {e}"))
206 })?;
207 let y = self.transform_1d(&x_1d)?;
208
209 match result.ndim() {
211 2 => {
212 if axis == 0 {
213 let mut result_slice = result.slice_mut(s![i, ..]);
214 result_slice.assign(&y);
215 } else {
216 let mut result_slice = result.slice_mut(s![.., i]);
217 result_slice.assign(&y);
218 }
219 }
220 _ => {
221 return Err(FFTError::ValueError(
223 "CZT currently only supports 1D and 2D arrays".to_string(),
224 ));
225 }
226 }
227 }
228
229 Ok(result)
230 }
231
232 fn transform_1d(&self, x: &Array1<Complex<f64>>) -> FFTResult<Array1<Complex<f64>>> {
234 if x.len() != self.n {
235 return Err(FFTError::ValueError(format!(
236 "Input size ({}) doesn't match CZT size ({})",
237 x.len(),
238 self.n
239 )));
240 }
241
242 let x_weighted: Array1<Complex<f64>> = Zip::from(x)
244 .and(&self.awk2)
245 .map_collect(|&xi, &awki| xi * awki);
246
247 let mut padded = Array1::zeros(self.nfft);
249 padded.slice_mut(s![..self.n]).assign(&x_weighted);
250
251 let x_fft_vec = crate::fft::fft(&padded.to_vec(), None)?;
253 let x_fft = Array1::from_vec(x_fft_vec);
254
255 let product: Array1<Complex<f64>> = Zip::from(&x_fft)
257 .and(&self.fwk2)
258 .map_collect(|&xi, &fi| xi * fi);
259
260 let y_full_vec = crate::fft::ifft(&product.to_vec(), None)?;
262 let y_full = Array1::from_vec(y_full_vec);
263
264 let y_slice = y_full.slice(s![self.n - 1..self.n - 1 + self.m]);
266 let result: Array1<Complex<f64>> = Zip::from(&y_slice)
267 .and(&self.wk2)
268 .map_collect(|&yi, &wki| yi * wki);
269
270 Ok(result)
271 }
272}
273
274#[allow(dead_code)]
283pub fn czt<S, D>(
284 x: &ArrayBase<S, D>,
285 m: Option<usize>,
286 w: Option<Complex<f64>>,
287 a: Option<Complex<f64>>,
288 axis: Option<i32>,
289) -> FFTResult<ArrayD<Complex<f64>>>
290where
291 S: Data<Elem = Complex<f64>>,
292 D: Dimension + RemoveAxis,
293{
294 let axis_actual = if let Some(ax) = axis {
295 if ax < 0 {
296 (x.ndim() as i32 + ax) as usize
297 } else {
298 ax as usize
299 }
300 } else {
301 x.ndim() - 1
302 };
303
304 let n = x.shape()[axis_actual];
305 let transform = CZT::new(n, m, w, a)?;
306 transform.transform(x, axis)
307}
308
309#[allow(dead_code)]
320pub fn zoom_fft<S, D>(
321 x: &ArrayBase<S, D>,
322 m: usize,
323 f0: f64,
324 f1: f64,
325 oversampling: Option<f64>,
326) -> FFTResult<ArrayD<Complex<f64>>>
327where
328 S: Data<Elem = Complex<f64>>,
329 D: Dimension + RemoveAxis,
330{
331 if !(0.0..=1.0).contains(&f0) || !(0.0..=1.0).contains(&f1) {
332 return Err(FFTError::ValueError(
333 "Frequencies must be in range [0, 1]".to_string(),
334 ));
335 }
336
337 if f0 >= f1 {
338 return Err(FFTError::ValueError("f0 must be less than f1".to_string()));
339 }
340
341 let oversampling = oversampling.unwrap_or(2.0);
342 if oversampling < 1.0 {
343 return Err(FFTError::ValueError(
344 "Oversampling must be >= 1".to_string(),
345 ));
346 }
347
348 let ndim = x.ndim();
349 let axis = ndim - 1;
350 let n = x.shape()[axis];
351
352 let k0_float = f0 * n as f64 * oversampling;
354 let k1_float = f1 * n as f64 * oversampling;
355 let step = (k1_float - k0_float) / (m - 1) as f64;
356
357 let phi = 2.0 * PI * k0_float / (n as f64 * oversampling);
358 let a = Complex::from_polar(1.0, phi);
359
360 let theta = -2.0 * PI * step / (n as f64 * oversampling);
361 let w = Complex::from_polar(1.0, theta);
362
363 czt(x, Some(m), Some(w), Some(a), Some(axis as i32))
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// The default contour lies on the unit circle; an explicit contour starts at `a`.
    #[test]
    fn test_czt_points() {
        let circle = czt_points(4, None, None);
        assert_eq!(circle.len(), 4);
        for z in &circle {
            assert_abs_diff_eq!(z.norm(), 1.0, epsilon = 1e-10);
        }

        let start = Complex::new(0.8, 0.0);
        let ratio = Complex::from_polar(0.95, 0.1);
        let spiral = czt_points(5, Some(start), Some(ratio));
        assert_eq!(spiral.len(), 5);
        assert!((spiral[0] - start).norm() < 1e-10);
    }

    /// With every parameter defaulted, the CZT must reproduce the FFT of the same signal.
    #[test]
    fn test_czt_as_fft() {
        let n = 8;
        let signal: Array1<Complex<f64>> =
            Array1::linspace(0.0, 7.0, n).mapv(|v| Complex::new(v, 0.0));

        let spectrum: Array1<Complex<f64>> = {
            let raw = czt(&signal.view(), None, None, None, None)
                .expect("CZT computation should succeed for test data");
            assert_eq!(raw.ndim(), 1);
            raw.into_dimensionality()
                .expect("CZT result should convert to 1D array")
        };

        let reference = Array1::from_vec(
            crate::fft::fft(&signal.to_vec(), None)
                .expect("FFT computation should succeed for test data"),
        );

        for k in 0..n {
            let (got, want) = (spectrum[k], reference[k]);
            assert!((got.re - want.re).abs() < 1e-10);
            assert!((got.im - want.im).abs() < 1e-10);
        }
    }

    /// Zooming onto [0, 0.5] of a 5-cycle sine yields `m` samples that are not all zero.
    #[test]
    fn test_zoom_fft() {
        let n = 64;
        let m = 16;

        let t: Array1<f64> = Array1::linspace(0.0, 1.0, n);
        let signal: Array1<Complex<f64>> =
            t.mapv(|ti| Complex::new((2.0 * PI * 5.0 * ti).sin(), 0.0));

        let zoomed = zoom_fft(&signal.view(), m, 0.0, 0.5, None)
            .expect("Zoom FFT should succeed for test data");
        assert_eq!(zoomed.ndim(), 1);

        let zoomed: Array1<Complex<f64>> = zoomed
            .into_dimensionality()
            .expect("Zoom FFT result should convert to 1D array");
        assert_eq!(zoomed.len(), m);

        let has_nonzero = zoomed.iter().any(|c| c.norm() > 1e-10);
        assert!(has_nonzero, "Zoom FFT should produce some non-zero values");
    }
}