Skip to main content

ad_plugins_rs/
fft.rs

1use std::sync::Arc;
2
3use ad_core_rs::ndarray::{NDArray, NDDataBuffer, NDDataType, NDDimension};
4use ad_core_rs::ndarray_pool::NDArrayPool;
5use ad_core_rs::plugin::runtime::{NDPluginProcess, ProcessResult};
6use rustfft::FftPlanner;
7use rustfft::num_complex::Complex;
8
/// Selects how the transform is laid out over the input array.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FFTMode {
    /// Independent 1D transform of every row (dimension 0 is the row length).
    Rows1D,
    /// One 2D transform over dimensions 0 and 1.
    Full2D,
}
15
/// Which way the transform runs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FFTDirection {
    /// Sample domain to frequency domain; the plugin emits positive-frequency
    /// magnitudes.
    Forward,
    /// Frequency domain back to the sample domain; the plugin emits the magnitude
    /// of every reconstructed sample.
    Inverse,
}
22
23/// Configuration for FFT processing.
24pub struct FFTConfig {
25    pub mode: FFTMode,
26    pub direction: FFTDirection,
27    /// Zero out DC component (k=0) in the output magnitudes.
28    pub suppress_dc: bool,
29    /// Average N frames of magnitude. 0 or 1 means no averaging.
30    pub num_average: usize,
31}
32
33impl Default for FFTConfig {
34    fn default() -> Self {
35        Self {
36            mode: FFTMode::Rows1D,
37            direction: FFTDirection::Forward,
38            suppress_dc: false,
39            num_average: 0,
40        }
41    }
42}
43
44/// Compute 1D FFT magnitude for each row of a 2D array using rustfft.
45/// Returns a Float64 array with half the width (positive frequencies only, matching C++).
46/// Magnitudes are normalized by N (C++: `FFTAbsValue[j] = sqrt(...) / nTimeX`).
47pub fn fft_1d_rows(src: &NDArray, suppress_dc: bool) -> Option<NDArray> {
48    if src.dims.is_empty() {
49        return None;
50    }
51
52    let width = src.dims[0].size;
53    let height = if src.dims.len() >= 2 {
54        src.dims[1].size
55    } else {
56        1
57    };
58
59    if width == 0 {
60        return None;
61    }
62
63    let mut planner = FftPlanner::<f64>::new();
64    let fft = planner.plan_fft_forward(width);
65
66    // C++: nFreqX = nTimeX / 2 (only positive frequencies)
67    let n_freq = width / 2;
68    if n_freq == 0 {
69        return None;
70    }
71    let scale = 1.0 / width as f64;
72
73    let mut magnitudes = vec![0.0f64; n_freq * height];
74    let mut row_buf = vec![Complex::new(0.0, 0.0); width];
75
76    for row in 0..height {
77        // Fill complex buffer: real = pixel value, imag = 0
78        for i in 0..width {
79            row_buf[i] = Complex::new(src.data.get_as_f64(row * width + i).unwrap_or(0.0), 0.0);
80        }
81
82        fft.process(&mut row_buf);
83
84        // Compute magnitudes (normalized by N, only first half)
85        for i in 0..n_freq {
86            magnitudes[row * n_freq + i] = row_buf[i].norm() * scale;
87        }
88
89        if suppress_dc {
90            magnitudes[row * n_freq] = 0.0;
91        }
92    }
93
94    let dims = if height > 1 {
95        vec![NDDimension::new(n_freq), NDDimension::new(height)]
96    } else {
97        vec![NDDimension::new(n_freq)]
98    };
99    let mut arr = NDArray::new(dims, NDDataType::Float64);
100    arr.data = NDDataBuffer::F64(magnitudes);
101    arr.unique_id = src.unique_id;
102    arr.timestamp = src.timestamp;
103    arr.attributes = src.attributes.clone();
104    Some(arr)
105}
106
107/// Compute 2D FFT magnitude using separable row-then-column FFT via rustfft.
108pub fn fft_2d(src: &NDArray, suppress_dc: bool) -> Option<NDArray> {
109    if src.dims.len() < 2 {
110        return None;
111    }
112
113    let w = src.dims[0].size;
114    let h = src.dims[1].size;
115
116    if w == 0 || h == 0 {
117        return None;
118    }
119
120    let mut planner = FftPlanner::<f64>::new();
121    let fft_row = planner.plan_fft_forward(w);
122    let fft_col = planner.plan_fft_forward(h);
123
124    // Step 1: Row FFTs — build a w×h complex buffer
125    let mut data = vec![Complex::new(0.0, 0.0); w * h];
126    let mut row_buf = vec![Complex::new(0.0, 0.0); w];
127
128    for row in 0..h {
129        for i in 0..w {
130            row_buf[i] = Complex::new(src.data.get_as_f64(row * w + i).unwrap_or(0.0), 0.0);
131        }
132        fft_row.process(&mut row_buf);
133        data[row * w..(row * w + w)].copy_from_slice(&row_buf);
134    }
135
136    // Step 2: Column FFTs
137    let mut col_buf = vec![Complex::new(0.0, 0.0); h];
138
139    for col in 0..w {
140        // Extract column
141        for row in 0..h {
142            col_buf[row] = data[row * w + col];
143        }
144        fft_col.process(&mut col_buf);
145        // Write back
146        for row in 0..h {
147            data[row * w + col] = col_buf[row];
148        }
149    }
150
151    // Step 3: Compute magnitudes (half spectrum, normalized by N*M)
152    let n_freq_x = w / 2;
153    let n_freq_y = h / 2;
154    if n_freq_x == 0 || n_freq_y == 0 {
155        return None;
156    }
157    let scale = 1.0 / (w * h) as f64;
158
159    let mut magnitudes = vec![0.0f64; n_freq_x * n_freq_y];
160    for fy in 0..n_freq_y {
161        for fx in 0..n_freq_x {
162            magnitudes[fy * n_freq_x + fx] = data[fy * w + fx].norm() * scale;
163        }
164    }
165
166    if suppress_dc {
167        magnitudes[0] = 0.0;
168    }
169
170    let dims = vec![NDDimension::new(n_freq_x), NDDimension::new(n_freq_y)];
171    let mut arr = NDArray::new(dims, NDDataType::Float64);
172    arr.data = NDDataBuffer::F64(magnitudes);
173    arr.unique_id = src.unique_id;
174    arr.timestamp = src.timestamp;
175    arr.attributes = src.attributes.clone();
176    Some(arr)
177}
178
/// Parameter-library indices of the runtime-tunable FFT parameters.
///
/// Each entry is filled in by `register_params` via `find_param`. `None` means
/// the parameter was not found, so it never matches a change `reason` and is
/// never published.
#[derive(Debug, Default)]
struct FFTParamIndices {
    /// `FFT_DIRECTION`: 0 selects forward, any other value selects inverse.
    direction: Option<usize>,
    /// `FFT_SUPPRESS_DC`: non-zero enables DC suppression.
    suppress_dc: Option<usize>,
    /// `FFT_NUM_AVERAGE`: averaging window length (negative values clamp to 0).
    num_average: Option<usize>,
    /// `FFT_NUM_AVERAGED`: frames currently in the average (written by the plugin).
    num_averaged: Option<usize>,
    /// `FFT_RESET_AVERAGE`: writing a non-zero value discards the running average.
    reset_average: Option<usize>,
}
188
189pub struct FFTProcessor {
190    config: FFTConfig,
191    planner: FftPlanner<f64>,
192    /// Running average magnitude buffer.
193    avg_buffer: Option<Vec<f64>>,
194    /// Number of frames accumulated so far.
195    avg_count: usize,
196    /// Cached dimensions to detect changes.
197    cached_dims: Vec<usize>,
198    params: FFTParamIndices,
199}
200
201impl FFTProcessor {
202    pub fn new(mode: FFTMode) -> Self {
203        Self {
204            config: FFTConfig {
205                mode,
206                direction: FFTDirection::Forward,
207                suppress_dc: false,
208                num_average: 0,
209            },
210            planner: FftPlanner::new(),
211            avg_buffer: None,
212            avg_count: 0,
213            cached_dims: Vec::new(),
214            params: FFTParamIndices::default(),
215        }
216    }
217
218    pub fn with_config(config: FFTConfig) -> Self {
219        Self {
220            config,
221            planner: FftPlanner::new(),
222            avg_buffer: None,
223            avg_count: 0,
224            cached_dims: Vec::new(),
225            params: FFTParamIndices::default(),
226        }
227    }
228
229    /// Check if dimensions changed and reset averaging state if so.
230    fn check_dims_changed(&mut self, dims: &[NDDimension]) {
231        let current: Vec<usize> = dims.iter().map(|d| d.size).collect();
232        if current != self.cached_dims {
233            self.cached_dims = current;
234            self.avg_buffer = None;
235            self.avg_count = 0;
236        }
237    }
238
239    /// Compute FFT using cached planner for plan reuse across frames.
240    fn compute_fft(&mut self, src: &NDArray) -> Option<NDArray> {
241        let suppress_dc = self.config.suppress_dc;
242
243        match (self.config.mode, self.config.direction) {
244            (FFTMode::Rows1D, FFTDirection::Forward) => {
245                self.compute_fft_1d_rows_forward(src, suppress_dc)
246            }
247            (FFTMode::Rows1D, FFTDirection::Inverse) => {
248                self.compute_fft_1d_rows_inverse(src, suppress_dc)
249            }
250            (FFTMode::Full2D, FFTDirection::Forward) => {
251                self.compute_fft_2d_forward(src, suppress_dc)
252            }
253            (FFTMode::Full2D, FFTDirection::Inverse) => {
254                self.compute_fft_2d_inverse(src, suppress_dc)
255            }
256        }
257    }
258
259    fn compute_fft_1d_rows_forward(&mut self, src: &NDArray, suppress_dc: bool) -> Option<NDArray> {
260        if src.dims.is_empty() {
261            return None;
262        }
263
264        let width = src.dims[0].size;
265        let height = if src.dims.len() >= 2 {
266            src.dims[1].size
267        } else {
268            1
269        };
270
271        if width == 0 {
272            return None;
273        }
274
275        let fft = self.planner.plan_fft_forward(width);
276
277        // C++: nFreqX = nTimeX / 2 (only positive frequencies)
278        let n_freq = width / 2;
279        if n_freq == 0 {
280            return None;
281        }
282        let scale = 1.0 / width as f64;
283
284        let mut magnitudes = vec![0.0f64; n_freq * height];
285        let mut row_buf = vec![Complex::new(0.0, 0.0); width];
286
287        for row in 0..height {
288            for i in 0..width {
289                row_buf[i] = Complex::new(src.data.get_as_f64(row * width + i).unwrap_or(0.0), 0.0);
290            }
291            fft.process(&mut row_buf);
292            for i in 0..n_freq {
293                magnitudes[row * n_freq + i] = row_buf[i].norm() * scale;
294            }
295            if suppress_dc {
296                magnitudes[row * n_freq] = 0.0;
297            }
298        }
299
300        let dims = if height > 1 {
301            vec![NDDimension::new(n_freq), NDDimension::new(height)]
302        } else {
303            vec![NDDimension::new(n_freq)]
304        };
305        let mut arr = NDArray::new(dims, NDDataType::Float64);
306        arr.data = NDDataBuffer::F64(magnitudes);
307        arr.unique_id = src.unique_id;
308        arr.timestamp = src.timestamp;
309        arr.attributes = src.attributes.clone();
310        Some(arr)
311    }
312
313    fn compute_fft_1d_rows_inverse(&mut self, src: &NDArray, suppress_dc: bool) -> Option<NDArray> {
314        if src.dims.is_empty() {
315            return None;
316        }
317
318        let width = src.dims[0].size;
319        let height = if src.dims.len() >= 2 {
320            src.dims[1].size
321        } else {
322            1
323        };
324
325        if width == 0 {
326            return None;
327        }
328
329        let fft = self.planner.plan_fft_inverse(width);
330        let scale = 1.0 / width as f64;
331
332        let mut magnitudes = vec![0.0f64; width * height];
333        let mut row_buf = vec![Complex::new(0.0, 0.0); width];
334
335        for row in 0..height {
336            for i in 0..width {
337                row_buf[i] = Complex::new(src.data.get_as_f64(row * width + i).unwrap_or(0.0), 0.0);
338            }
339            if suppress_dc {
340                row_buf[0] = Complex::new(0.0, 0.0);
341            }
342            fft.process(&mut row_buf);
343            for (i, c) in row_buf.iter().enumerate() {
344                magnitudes[row * width + i] = c.norm() * scale;
345            }
346        }
347
348        let dims = src.dims.clone();
349        let mut arr = NDArray::new(dims, NDDataType::Float64);
350        arr.data = NDDataBuffer::F64(magnitudes);
351        arr.unique_id = src.unique_id;
352        arr.timestamp = src.timestamp;
353        arr.attributes = src.attributes.clone();
354        Some(arr)
355    }
356
357    fn compute_fft_2d_forward(&mut self, src: &NDArray, suppress_dc: bool) -> Option<NDArray> {
358        if src.dims.len() < 2 {
359            return None;
360        }
361
362        let w = src.dims[0].size;
363        let h = src.dims[1].size;
364
365        if w == 0 || h == 0 {
366            return None;
367        }
368
369        let fft_row = self.planner.plan_fft_forward(w);
370        let fft_col = self.planner.plan_fft_forward(h);
371
372        let mut data = vec![Complex::new(0.0, 0.0); w * h];
373        let mut row_buf = vec![Complex::new(0.0, 0.0); w];
374
375        for row in 0..h {
376            for i in 0..w {
377                row_buf[i] = Complex::new(src.data.get_as_f64(row * w + i).unwrap_or(0.0), 0.0);
378            }
379            fft_row.process(&mut row_buf);
380            data[row * w..(row * w + w)].copy_from_slice(&row_buf);
381        }
382
383        let mut col_buf = vec![Complex::new(0.0, 0.0); h];
384        for col in 0..w {
385            for row in 0..h {
386                col_buf[row] = data[row * w + col];
387            }
388            fft_col.process(&mut col_buf);
389            for row in 0..h {
390                data[row * w + col] = col_buf[row];
391            }
392        }
393
394        // C++: nFreqX = nTimeX/2, nFreqY = nTimeY/2; normalize by N*M
395        let n_freq_x = w / 2;
396        let n_freq_y = h / 2;
397        if n_freq_x == 0 || n_freq_y == 0 {
398            return None;
399        }
400        let scale = 1.0 / (w * h) as f64;
401
402        let mut magnitudes = vec![0.0f64; n_freq_x * n_freq_y];
403        for fy in 0..n_freq_y {
404            for fx in 0..n_freq_x {
405                magnitudes[fy * n_freq_x + fx] = data[fy * w + fx].norm() * scale;
406            }
407        }
408
409        if suppress_dc {
410            magnitudes[0] = 0.0;
411        }
412
413        let dims = vec![NDDimension::new(n_freq_x), NDDimension::new(n_freq_y)];
414        let mut arr = NDArray::new(dims, NDDataType::Float64);
415        arr.data = NDDataBuffer::F64(magnitudes);
416        arr.unique_id = src.unique_id;
417        arr.timestamp = src.timestamp;
418        arr.attributes = src.attributes.clone();
419        Some(arr)
420    }
421
422    fn compute_fft_2d_inverse(&mut self, src: &NDArray, suppress_dc: bool) -> Option<NDArray> {
423        if src.dims.len() < 2 {
424            return None;
425        }
426
427        let w = src.dims[0].size;
428        let h = src.dims[1].size;
429
430        if w == 0 || h == 0 {
431            return None;
432        }
433
434        let fft_row = self.planner.plan_fft_inverse(w);
435        let fft_col = self.planner.plan_fft_inverse(h);
436        let scale = 1.0 / (w * h) as f64;
437
438        let mut data = vec![Complex::new(0.0, 0.0); w * h];
439        for i in 0..w * h {
440            data[i] = Complex::new(src.data.get_as_f64(i).unwrap_or(0.0), 0.0);
441        }
442
443        if suppress_dc {
444            data[0] = Complex::new(0.0, 0.0);
445        }
446
447        let mut col_buf = vec![Complex::new(0.0, 0.0); h];
448        for col in 0..w {
449            for row in 0..h {
450                col_buf[row] = data[row * w + col];
451            }
452            fft_col.process(&mut col_buf);
453            for row in 0..h {
454                data[row * w + col] = col_buf[row];
455            }
456        }
457
458        let mut row_buf = vec![Complex::new(0.0, 0.0); w];
459        for row in 0..h {
460            row_buf.copy_from_slice(&data[row * w..(row * w + w)]);
461            fft_row.process(&mut row_buf);
462            data[row * w..(row * w + w)].copy_from_slice(&row_buf);
463        }
464
465        let magnitudes: Vec<f64> = data.iter().map(|c| c.norm() * scale).collect();
466
467        let dims = vec![NDDimension::new(w), NDDimension::new(h)];
468        let mut arr = NDArray::new(dims, NDDataType::Float64);
469        arr.data = NDDataBuffer::F64(magnitudes);
470        arr.unique_id = src.unique_id;
471        arr.timestamp = src.timestamp;
472        arr.attributes = src.attributes.clone();
473        Some(arr)
474    }
475
476    /// Apply magnitude averaging using exponential moving average (matching C++).
477    ///
478    /// C++: `FFTAbsValue_[j] = FFTAbsValue_[j] * oldFraction + new[j] * newFraction`
479    /// where `oldFraction = 1 - 1/numAveraged`, `newFraction = 1/numAveraged`.
480    fn apply_averaging(&mut self, magnitudes: &[f64]) -> Vec<f64> {
481        let num_avg = self.config.num_average;
482        if num_avg <= 1 {
483            return magnitudes.to_vec();
484        }
485
486        let buf = self
487            .avg_buffer
488            .get_or_insert_with(|| vec![0.0; magnitudes.len()]);
489
490        // Reset if buffer size changed
491        if buf.len() != magnitudes.len() {
492            *buf = vec![0.0; magnitudes.len()];
493            self.avg_count = 0;
494        }
495
496        self.avg_count += 1;
497        // Cap at num_average for the weighting
498        let n = self.avg_count.min(num_avg) as f64;
499        let new_fraction = 1.0 / n;
500        let old_fraction = 1.0 - new_fraction;
501
502        // C++ exponential moving average
503        for (b, &m) in buf.iter_mut().zip(magnitudes.iter()) {
504            *b = *b * old_fraction + m * new_fraction;
505        }
506
507        buf.clone()
508    }
509}
510
511impl NDPluginProcess for FFTProcessor {
512    fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
513        use ad_core_rs::plugin::runtime::ParamUpdate;
514
515        self.check_dims_changed(&array.dims);
516
517        let result = self.compute_fft(array);
518        let mut updates = Vec::new();
519        if let Some(idx) = self.params.num_averaged {
520            updates.push(ParamUpdate::int32(idx, self.avg_count as i32));
521        }
522
523        match result {
524            Some(mut out) => {
525                if self.config.num_average > 1 {
526                    if let NDDataBuffer::F64(ref mags) = out.data {
527                        let averaged = self.apply_averaging(mags);
528                        out.data = NDDataBuffer::F64(averaged);
529                    }
530                }
531                let mut r = ProcessResult::arrays(vec![Arc::new(out)]);
532                r.param_updates = updates;
533                r
534            }
535            None => ProcessResult::sink(updates),
536        }
537    }
538
539    fn plugin_type(&self) -> &str {
540        "NDPluginFFT"
541    }
542
543    fn register_params(
544        &mut self,
545        base: &mut asyn_rs::port::PortDriverBase,
546    ) -> asyn_rs::error::AsynResult<()> {
547        use asyn_rs::param::ParamType;
548        base.create_param("FFT_TIME_PER_POINT", ParamType::Float64)?;
549        base.create_param("FFT_TIME_AXIS", ParamType::Float64Array)?;
550        base.create_param("FFT_FREQ_AXIS", ParamType::Float64Array)?;
551        base.create_param("FFT_DIRECTION", ParamType::Int32)?;
552        base.create_param("FFT_SUPPRESS_DC", ParamType::Int32)?;
553        base.create_param("FFT_NUM_AVERAGE", ParamType::Int32)?;
554        base.create_param("FFT_NUM_AVERAGED", ParamType::Int32)?;
555        base.create_param("FFT_RESET_AVERAGE", ParamType::Int32)?;
556        base.create_param("FFT_TIME_SERIES", ParamType::Float64Array)?;
557        base.create_param("FFT_REAL", ParamType::Float64Array)?;
558        base.create_param("FFT_IMAGINARY", ParamType::Float64Array)?;
559        base.create_param("FFT_ABS_VALUE", ParamType::Float64Array)?;
560
561        self.params.direction = base.find_param("FFT_DIRECTION");
562        self.params.suppress_dc = base.find_param("FFT_SUPPRESS_DC");
563        self.params.num_average = base.find_param("FFT_NUM_AVERAGE");
564        self.params.num_averaged = base.find_param("FFT_NUM_AVERAGED");
565        self.params.reset_average = base.find_param("FFT_RESET_AVERAGE");
566        Ok(())
567    }
568
569    fn on_param_change(
570        &mut self,
571        reason: usize,
572        params: &ad_core_rs::plugin::runtime::PluginParamSnapshot,
573    ) -> ad_core_rs::plugin::runtime::ParamChangeResult {
574        if Some(reason) == self.params.direction {
575            self.config.direction = if params.value.as_i32() == 0 {
576                FFTDirection::Forward
577            } else {
578                FFTDirection::Inverse
579            };
580        } else if Some(reason) == self.params.suppress_dc {
581            self.config.suppress_dc = params.value.as_i32() != 0;
582        } else if Some(reason) == self.params.num_average {
583            self.config.num_average = params.value.as_i32().max(0) as usize;
584        } else if Some(reason) == self.params.reset_average {
585            if params.value.as_i32() != 0 {
586                self.avg_buffer = None;
587                self.avg_count = 0;
588            }
589        }
590        ad_core_rs::plugin::runtime::ParamChangeResult::updates(vec![])
591    }
592}
593
#[cfg(test)]
mod tests {
    use super::*;

    use std::f64::consts::PI;

    /// Tolerance for comparing normalized FFT magnitudes.
    const EPS: f64 = 1e-10;

    /// Build a Float64 array with the given dimension sizes, setting flat element
    /// `i` (dimension 0 varies fastest) to `value(i)`.
    fn float64_array(dims: &[usize], value: impl Fn(usize) -> f64) -> NDArray {
        let mut arr = NDArray::new(
            dims.iter().map(|&d| NDDimension::new(d)).collect(),
            NDDataType::Float64,
        );
        match arr.data {
            NDDataBuffer::F64(ref mut v) => {
                for (i, x) in v.iter_mut().enumerate() {
                    *x = value(i);
                }
            }
            _ => panic!("NDArray::new(.., Float64) did not allocate F64 data"),
        }
        arr
    }

    /// Borrow the Float64 payload of `arr`, failing the test for any other type.
    fn f64_data(arr: &NDArray) -> &[f64] {
        match arr.data {
            NDDataBuffer::F64(ref v) => v.as_slice(),
            _ => panic!("expected F64 data"),
        }
    }

    /// Assert that `actual` is within `EPS` of `expected`, labelling failures.
    fn assert_close(actual: f64, expected: f64, what: &str) {
        assert!(
            (actual - expected).abs() < EPS,
            "{} = {}, expected {}",
            what,
            actual,
            expected
        );
    }

    #[test]
    fn test_fft_1d_dc() {
        // Constant signal: DC component should dominate
        let arr = float64_array(&[8], |_| 1.0);
        let result = fft_1d_rows(&arr, false).unwrap();
        // Output is half spectrum: N/2 = 4 bins
        assert_eq!(result.dims[0].size, 4);
        let v = f64_data(&result);
        // DC component normalized by N: 8/8 = 1.0
        assert_close(v[0], 1.0, "DC");
        // Other components should be ~0
        assert_close(v[1], 0.0, "k=1");
    }

    #[test]
    fn test_fft_1d_sine() {
        // Sine wave at frequency 1: peak at k=1 (k=N-1 is outside the half spectrum)
        let n = 16;
        let arr = float64_array(&[n], |i| (2.0 * PI * i as f64 / n as f64).sin());
        let result = fft_1d_rows(&arr, false).unwrap();
        // Output is N/2 = 8 bins
        assert_eq!(result.dims[0].size, 8);
        let v = f64_data(&result);
        assert_close(v[0], 0.0, "DC");
        // Peak at k=1, normalized by N: magnitude = N/2 / N = 0.5
        assert_close(v[1], 0.5, "k=1");
        assert_close(v[2], 0.0, "k=2");
    }

    #[test]
    fn test_fft_1d_too_short_returns_none() {
        // N = 1 has no positive-frequency bin.
        assert!(fft_1d_rows(&float64_array(&[1], |_| 1.0), false).is_none());
    }

    #[test]
    fn test_fft_2d_dimensions() {
        let arr = NDArray::new(
            vec![NDDimension::new(4), NDDimension::new(4)],
            NDDataType::UInt8,
        );
        let result = fft_2d(&arr, false).unwrap();
        // Half spectrum: 4/2 x 4/2 = 2x2
        assert_eq!(result.dims[0].size, 2);
        assert_eq!(result.dims[1].size, 2);
        assert_eq!(result.data.data_type(), NDDataType::Float64);
    }

    #[test]
    fn test_fft_1d_suppress_dc() {
        // Constant signal: DC component should be suppressed
        let arr = float64_array(&[8], |_| 1.0);
        let result = fft_1d_rows(&arr, true).unwrap();
        let v = f64_data(&result);
        // DC component should be zeroed out
        assert!(v[0].abs() < 1e-15);
        // Other components should still be ~0 for constant signal
        assert_close(v[1], 0.0, "k=1");
    }

    #[test]
    fn test_fft_2d_suppress_dc() {
        // 4x4 constant array, suppress_dc should zero out [0,0]
        let arr = float64_array(&[4, 4], |_| 3.0);
        let result = fft_2d(&arr, true).unwrap();
        assert!(f64_data(&result)[0].abs() < 1e-15);
    }

    #[test]
    fn test_fft_2d_known_dc() {
        // 4x4 constant=2.0 => DC = 4*4*2 = 32, normalized by 4*4 = 16 => 2.0
        let arr = float64_array(&[4, 4], |_| 2.0);
        let result = fft_2d(&arr, false).unwrap();
        // Half spectrum: 2x2
        assert_eq!(result.dims[0].size, 2);
        assert_eq!(result.dims[1].size, 2);
        let v = f64_data(&result);
        assert_close(v[0], 2.0, "DC");
        // All other bins should be ~0
        for (i, &m) in v.iter().enumerate().skip(1) {
            assert_close(m, 0.0, &format!("bin {}", i));
        }
    }

    #[test]
    fn test_fft_2d_separable_peaks() {
        // Non-square 8x4 plane: a cosine along x (kx=1) plus a cosine along y
        // (ky=1). Each puts N*M/2 into its bin; normalized by N*M that is 0.5.
        let (w, h) = (8, 4);
        let arr = float64_array(&[w, h], |i| {
            let (x, y) = ((i % w) as f64, (i / w) as f64);
            (2.0 * PI * x / w as f64).cos() + (2.0 * PI * y / h as f64).cos()
        });
        let result = fft_2d(&arr, false).unwrap();
        assert_eq!(result.dims[0].size, w / 2);
        assert_eq!(result.dims[1].size, h / 2);
        let v = f64_data(&result);
        assert_eq!(v.len(), (w / 2) * (h / 2));
        for (i, &m) in v.iter().enumerate() {
            let (fx, fy) = (i % (w / 2), i / (w / 2));
            let expected = if (fx, fy) == (1, 0) || (fx, fy) == (0, 1) {
                0.5
            } else {
                0.0
            };
            assert_close(m, expected, &format!("bin ({}, {})", fx, fy));
        }
    }

    #[test]
    fn test_fft_1d_known_cosine_peaks() {
        // Cosine at frequency 3 in N=16: peaks at k=3 and k=N-3=13
        let n = 16;
        let arr = float64_array(&[n], |i| (2.0 * PI * 3.0 * i as f64 / n as f64).cos());
        let result = fft_1d_rows(&arr, false).unwrap();
        // Half spectrum: 8 bins
        assert_eq!(result.dims[0].size, 8);
        let v = f64_data(&result);
        assert_close(v[0], 0.0, "DC");
        // k=3 should have magnitude N/2 / N = 8/16 = 0.5
        assert_close(v[3], 0.5, "k=3 magnitude");
        // Other bins in first half should be ~0
        for k in [1, 2, 4, 5, 6, 7] {
            assert_close(v[k], 0.0, &format!("k={} magnitude", k));
        }
    }

    #[test]
    fn test_processor_new_defaults() {
        let proc = FFTProcessor::new(FFTMode::Full2D);
        assert_eq!(proc.config.mode, FFTMode::Full2D);
        assert_eq!(proc.config.direction, FFTDirection::Forward);
        assert!(!proc.config.suppress_dc);
        assert_eq!(proc.config.num_average, 0);
        assert_eq!(proc.avg_count, 0);
    }

    #[test]
    fn test_processor_with_config() {
        let config = FFTConfig {
            mode: FFTMode::Rows1D,
            direction: FFTDirection::Forward,
            suppress_dc: true,
            num_average: 0,
        };
        let mut proc = FFTProcessor::with_config(config);
        let pool = NDArrayPool::new(0);

        let result = proc.process_array(&float64_array(&[8], |_| 5.0), &pool);
        assert_eq!(result.output_arrays.len(), 1);
        // suppress_dc: DC should be 0
        assert!(f64_data(&result.output_arrays[0])[0].abs() < 1e-15);
    }

    #[test]
    fn test_processor_averaging() {
        let config = FFTConfig {
            mode: FFTMode::Rows1D,
            direction: FFTDirection::Forward,
            suppress_dc: false,
            num_average: 2,
        };
        let mut proc = FFTProcessor::with_config(config);
        let pool = NDArrayPool::new(0);

        // Frame 1: constant = 2.0 => DC magnitude (normalized) = 2.0
        let arr1 = float64_array(&[8], |_| 2.0);
        // Frame 2: constant = 4.0 => DC magnitude (normalized) = 4.0
        let arr2 = float64_array(&[8], |_| 4.0);

        let r1 = proc.process_array(&arr1, &pool);
        assert_eq!(r1.output_arrays.len(), 1);
        // After 1 frame: exponential avg with N=1, so output = 2.0
        assert_close(f64_data(&r1.output_arrays[0])[0], 2.0, "partial avg DC");

        let r2 = proc.process_array(&arr2, &pool);
        assert_eq!(r2.output_arrays.len(), 1);
        // After 2 frames: exp avg = 2.0*(1-1/2) + 4.0*(1/2) = 1.0 + 2.0 = 3.0
        assert_close(f64_data(&r2.output_arrays[0])[0], 3.0, "averaged DC");
    }

    #[test]
    fn test_processor_averaging_dimension_change_resets() {
        let config = FFTConfig {
            mode: FFTMode::Rows1D,
            direction: FFTDirection::Forward,
            suppress_dc: false,
            num_average: 3,
        };
        let mut proc = FFTProcessor::with_config(config);
        let pool = NDArrayPool::new(0);

        // Frame 1: width=8
        let _ = proc.process_array(&float64_array(&[8], |_| 1.0), &pool);
        assert_eq!(proc.avg_count, 1);

        // Frame 2: width=4 — dimension change should reset
        let _ = proc.process_array(&float64_array(&[4], |_| 1.0), &pool);
        // After dimension change, avg_count should be 1 (reset + one new frame)
        assert_eq!(proc.avg_count, 1);
    }

    #[test]
    fn test_fft_1d_multirow() {
        // 2 rows: row 0 all 1.0, row 1 all 3.0
        let (w, h) = (4, 2);
        let arr = float64_array(&[w, h], |i| if i < w { 1.0 } else { 3.0 });

        let result = fft_1d_rows(&arr, false).unwrap();
        let n_freq = w / 2; // half spectrum
        assert_eq!(result.dims[0].size, n_freq);
        let v = f64_data(&result);
        // Row 0 DC = 4*1/4 = 1.0 (normalized by N=4)
        assert_close(v[0], 1.0, "row 0 DC");
        // Row 1 DC = 4*3/4 = 3.0
        assert_close(v[n_freq], 3.0, "row 1 DC");
    }

    #[test]
    fn test_inverse_fft_1d() {
        // For a real constant signal, the forward FFT gives [N, 0, 0, ... 0].
        // The IFFT of [N, 0, ... 0] (real input) gives the constant 1.0 per sample
        // after normalization.
        let n = 8;
        let arr = float64_array(&[n], |i| if i == 0 { 8.0 } else { 0.0 });

        let config = FFTConfig {
            mode: FFTMode::Rows1D,
            direction: FFTDirection::Inverse,
            suppress_dc: false,
            num_average: 0,
        };
        let mut proc = FFTProcessor::with_config(config);
        let pool = NDArrayPool::new(0);

        let result = proc.process_array(&arr, &pool);
        assert_eq!(result.output_arrays.len(), 1);
        let v = f64_data(&result.output_arrays[0]);
        assert_eq!(v.len(), n);
        // Each sample should be magnitude 1.0 (8/8 = 1.0 after normalization)
        for (i, &s) in v.iter().enumerate() {
            assert_close(s, 1.0, &format!("sample {}", i));
        }
    }

    #[test]
    fn test_inverse_fft_2d() {
        // A spectrum holding only DC = N*M inverts to the constant N*M, which
        // normalizes to 1.0 per sample.
        let config = FFTConfig {
            mode: FFTMode::Full2D,
            direction: FFTDirection::Inverse,
            suppress_dc: false,
            num_average: 0,
        };
        let mut proc = FFTProcessor::with_config(config);
        let pool = NDArrayPool::new(0);

        let arr = float64_array(&[4, 2], |i| if i == 0 { 8.0 } else { 0.0 });
        let result = proc.process_array(&arr, &pool);
        assert_eq!(result.output_arrays.len(), 1);
        let out = &result.output_arrays[0];
        assert_eq!(out.dims[0].size, 4);
        assert_eq!(out.dims[1].size, 2);
        let v = f64_data(out);
        assert_eq!(v.len(), 8);
        for (i, &s) in v.iter().enumerate() {
            assert_close(s, 1.0, &format!("sample {}", i));
        }
    }

    #[test]
    fn test_fft_preserves_metadata() {
        let mut arr = float64_array(&[4], |i| if i == 0 { 1.0 } else { 0.0 });
        arr.unique_id = 42;

        let result = fft_1d_rows(&arr, false).unwrap();
        assert_eq!(result.unique_id, 42);
        assert_eq!(result.timestamp, arr.timestamp);
    }
}