// ad_plugins/fft.rs

1use std::sync::Arc;
2
3use ad_core::ndarray::{NDArray, NDDataBuffer, NDDataType, NDDimension};
4use ad_core::ndarray_pool::NDArrayPool;
5use ad_core::plugin::runtime::{NDPluginProcess, ProcessResult};
6use rustfft::num_complex::Complex;
7use rustfft::FftPlanner;
8
/// Selects how the FFT is applied to an incoming array.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FFTMode {
    /// Transform every row (run along dimension 0) independently.
    Rows1D,
    /// Separable transform over the plane formed by the first two dimensions.
    Full2D,
}
15
/// Which way the transform runs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FFTDirection {
    /// Time/space domain to frequency domain.
    Forward,
    /// Frequency domain back to time/space domain.
    Inverse,
}
22
23/// Configuration for FFT processing.
24pub struct FFTConfig {
25    pub mode: FFTMode,
26    pub direction: FFTDirection,
27    /// Zero out DC component (k=0) in the output magnitudes.
28    pub suppress_dc: bool,
29    /// Average N frames of magnitude. 0 or 1 means no averaging.
30    pub num_average: usize,
31}
32
33impl Default for FFTConfig {
34    fn default() -> Self {
35        Self {
36            mode: FFTMode::Rows1D,
37            direction: FFTDirection::Forward,
38            suppress_dc: false,
39            num_average: 0,
40        }
41    }
42}
43
44/// Compute 1D FFT magnitude for each row of a 2D array using rustfft.
45/// Returns a Float64 array with the same dimensions.
46pub fn fft_1d_rows(src: &NDArray, suppress_dc: bool) -> Option<NDArray> {
47    if src.dims.is_empty() {
48        return None;
49    }
50
51    let width = src.dims[0].size;
52    let height = if src.dims.len() >= 2 { src.dims[1].size } else { 1 };
53
54    if width == 0 {
55        return None;
56    }
57
58    let mut planner = FftPlanner::<f64>::new();
59    let fft = planner.plan_fft_forward(width);
60
61    let mut magnitudes = vec![0.0f64; width * height];
62    let mut row_buf = vec![Complex::new(0.0, 0.0); width];
63
64    for row in 0..height {
65        // Fill complex buffer: real = pixel value, imag = 0
66        for i in 0..width {
67            row_buf[i] = Complex::new(
68                src.data.get_as_f64(row * width + i).unwrap_or(0.0),
69                0.0,
70            );
71        }
72
73        fft.process(&mut row_buf);
74
75        // Compute magnitudes
76        for (i, c) in row_buf.iter().enumerate() {
77            magnitudes[row * width + i] = c.norm();
78        }
79
80        if suppress_dc {
81            magnitudes[row * width] = 0.0;
82        }
83    }
84
85    let dims = src.dims.clone();
86    let mut arr = NDArray::new(dims, NDDataType::Float64);
87    arr.data = NDDataBuffer::F64(magnitudes);
88    arr.unique_id = src.unique_id;
89    arr.timestamp = src.timestamp;
90    arr.attributes = src.attributes.clone();
91    Some(arr)
92}
93
94/// Compute 2D FFT magnitude using separable row-then-column FFT via rustfft.
95pub fn fft_2d(src: &NDArray, suppress_dc: bool) -> Option<NDArray> {
96    if src.dims.len() < 2 {
97        return None;
98    }
99
100    let w = src.dims[0].size;
101    let h = src.dims[1].size;
102
103    if w == 0 || h == 0 {
104        return None;
105    }
106
107    let mut planner = FftPlanner::<f64>::new();
108    let fft_row = planner.plan_fft_forward(w);
109    let fft_col = planner.plan_fft_forward(h);
110
111    // Step 1: Row FFTs — build a w×h complex buffer
112    let mut data = vec![Complex::new(0.0, 0.0); w * h];
113    let mut row_buf = vec![Complex::new(0.0, 0.0); w];
114
115    for row in 0..h {
116        for i in 0..w {
117            row_buf[i] = Complex::new(
118                src.data.get_as_f64(row * w + i).unwrap_or(0.0),
119                0.0,
120            );
121        }
122        fft_row.process(&mut row_buf);
123        data[row * w..(row * w + w)].copy_from_slice(&row_buf);
124    }
125
126    // Step 2: Column FFTs
127    let mut col_buf = vec![Complex::new(0.0, 0.0); h];
128
129    for col in 0..w {
130        // Extract column
131        for row in 0..h {
132            col_buf[row] = data[row * w + col];
133        }
134        fft_col.process(&mut col_buf);
135        // Write back
136        for row in 0..h {
137            data[row * w + col] = col_buf[row];
138        }
139    }
140
141    // Step 3: Compute magnitudes
142    let mut magnitudes: Vec<f64> = data.iter().map(|c| c.norm()).collect();
143
144    if suppress_dc {
145        magnitudes[0] = 0.0;
146    }
147
148    let dims = vec![NDDimension::new(w), NDDimension::new(h)];
149    let mut arr = NDArray::new(dims, NDDataType::Float64);
150    arr.data = NDDataBuffer::F64(magnitudes);
151    arr.unique_id = src.unique_id;
152    arr.timestamp = src.timestamp;
153    arr.attributes = src.attributes.clone();
154    Some(arr)
155}
156
157/// FFT processing engine with cached planner and optional magnitude averaging.
158pub struct FFTProcessor {
159    config: FFTConfig,
160    planner: FftPlanner<f64>,
161    /// Running average magnitude buffer.
162    avg_buffer: Option<Vec<f64>>,
163    /// Number of frames accumulated so far.
164    avg_count: usize,
165    /// Cached dimensions to detect changes.
166    cached_dims: Vec<usize>,
167}
168
169impl FFTProcessor {
170    pub fn new(mode: FFTMode) -> Self {
171        Self {
172            config: FFTConfig {
173                mode,
174                direction: FFTDirection::Forward,
175                suppress_dc: false,
176                num_average: 0,
177            },
178            planner: FftPlanner::new(),
179            avg_buffer: None,
180            avg_count: 0,
181            cached_dims: Vec::new(),
182        }
183    }
184
185    pub fn with_config(config: FFTConfig) -> Self {
186        Self {
187            config,
188            planner: FftPlanner::new(),
189            avg_buffer: None,
190            avg_count: 0,
191            cached_dims: Vec::new(),
192        }
193    }
194
195    /// Check if dimensions changed and reset averaging state if so.
196    fn check_dims_changed(&mut self, dims: &[NDDimension]) {
197        let current: Vec<usize> = dims.iter().map(|d| d.size).collect();
198        if current != self.cached_dims {
199            self.cached_dims = current;
200            self.avg_buffer = None;
201            self.avg_count = 0;
202        }
203    }
204
205    /// Compute FFT using cached planner for plan reuse across frames.
206    fn compute_fft(&mut self, src: &NDArray) -> Option<NDArray> {
207        let suppress_dc = self.config.suppress_dc;
208
209        match (self.config.mode, self.config.direction) {
210            (FFTMode::Rows1D, FFTDirection::Forward) => {
211                self.compute_fft_1d_rows_forward(src, suppress_dc)
212            }
213            (FFTMode::Rows1D, FFTDirection::Inverse) => {
214                self.compute_fft_1d_rows_inverse(src, suppress_dc)
215            }
216            (FFTMode::Full2D, FFTDirection::Forward) => {
217                self.compute_fft_2d_forward(src, suppress_dc)
218            }
219            (FFTMode::Full2D, FFTDirection::Inverse) => {
220                self.compute_fft_2d_inverse(src, suppress_dc)
221            }
222        }
223    }
224
225    fn compute_fft_1d_rows_forward(
226        &mut self,
227        src: &NDArray,
228        suppress_dc: bool,
229    ) -> Option<NDArray> {
230        if src.dims.is_empty() {
231            return None;
232        }
233
234        let width = src.dims[0].size;
235        let height = if src.dims.len() >= 2 { src.dims[1].size } else { 1 };
236
237        if width == 0 {
238            return None;
239        }
240
241        let fft = self.planner.plan_fft_forward(width);
242
243        let mut magnitudes = vec![0.0f64; width * height];
244        let mut row_buf = vec![Complex::new(0.0, 0.0); width];
245
246        for row in 0..height {
247            for i in 0..width {
248                row_buf[i] = Complex::new(
249                    src.data.get_as_f64(row * width + i).unwrap_or(0.0),
250                    0.0,
251                );
252            }
253            fft.process(&mut row_buf);
254            for (i, c) in row_buf.iter().enumerate() {
255                magnitudes[row * width + i] = c.norm();
256            }
257            if suppress_dc {
258                magnitudes[row * width] = 0.0;
259            }
260        }
261
262        let dims = src.dims.clone();
263        let mut arr = NDArray::new(dims, NDDataType::Float64);
264        arr.data = NDDataBuffer::F64(magnitudes);
265        arr.unique_id = src.unique_id;
266        arr.timestamp = src.timestamp;
267        arr.attributes = src.attributes.clone();
268        Some(arr)
269    }
270
271    fn compute_fft_1d_rows_inverse(
272        &mut self,
273        src: &NDArray,
274        suppress_dc: bool,
275    ) -> Option<NDArray> {
276        if src.dims.is_empty() {
277            return None;
278        }
279
280        let width = src.dims[0].size;
281        let height = if src.dims.len() >= 2 { src.dims[1].size } else { 1 };
282
283        if width == 0 {
284            return None;
285        }
286
287        let fft = self.planner.plan_fft_inverse(width);
288        let scale = 1.0 / width as f64;
289
290        let mut magnitudes = vec![0.0f64; width * height];
291        let mut row_buf = vec![Complex::new(0.0, 0.0); width];
292
293        for row in 0..height {
294            for i in 0..width {
295                row_buf[i] = Complex::new(
296                    src.data.get_as_f64(row * width + i).unwrap_or(0.0),
297                    0.0,
298                );
299            }
300            if suppress_dc {
301                row_buf[0] = Complex::new(0.0, 0.0);
302            }
303            fft.process(&mut row_buf);
304            for (i, c) in row_buf.iter().enumerate() {
305                magnitudes[row * width + i] = c.norm() * scale;
306            }
307        }
308
309        let dims = src.dims.clone();
310        let mut arr = NDArray::new(dims, NDDataType::Float64);
311        arr.data = NDDataBuffer::F64(magnitudes);
312        arr.unique_id = src.unique_id;
313        arr.timestamp = src.timestamp;
314        arr.attributes = src.attributes.clone();
315        Some(arr)
316    }
317
318    fn compute_fft_2d_forward(
319        &mut self,
320        src: &NDArray,
321        suppress_dc: bool,
322    ) -> Option<NDArray> {
323        if src.dims.len() < 2 {
324            return None;
325        }
326
327        let w = src.dims[0].size;
328        let h = src.dims[1].size;
329
330        if w == 0 || h == 0 {
331            return None;
332        }
333
334        let fft_row = self.planner.plan_fft_forward(w);
335        let fft_col = self.planner.plan_fft_forward(h);
336
337        let mut data = vec![Complex::new(0.0, 0.0); w * h];
338        let mut row_buf = vec![Complex::new(0.0, 0.0); w];
339
340        for row in 0..h {
341            for i in 0..w {
342                row_buf[i] = Complex::new(
343                    src.data.get_as_f64(row * w + i).unwrap_or(0.0),
344                    0.0,
345                );
346            }
347            fft_row.process(&mut row_buf);
348            data[row * w..(row * w + w)].copy_from_slice(&row_buf);
349        }
350
351        let mut col_buf = vec![Complex::new(0.0, 0.0); h];
352        for col in 0..w {
353            for row in 0..h {
354                col_buf[row] = data[row * w + col];
355            }
356            fft_col.process(&mut col_buf);
357            for row in 0..h {
358                data[row * w + col] = col_buf[row];
359            }
360        }
361
362        let mut magnitudes: Vec<f64> = data.iter().map(|c| c.norm()).collect();
363
364        if suppress_dc {
365            magnitudes[0] = 0.0;
366        }
367
368        let dims = vec![NDDimension::new(w), NDDimension::new(h)];
369        let mut arr = NDArray::new(dims, NDDataType::Float64);
370        arr.data = NDDataBuffer::F64(magnitudes);
371        arr.unique_id = src.unique_id;
372        arr.timestamp = src.timestamp;
373        arr.attributes = src.attributes.clone();
374        Some(arr)
375    }
376
377    fn compute_fft_2d_inverse(
378        &mut self,
379        src: &NDArray,
380        suppress_dc: bool,
381    ) -> Option<NDArray> {
382        if src.dims.len() < 2 {
383            return None;
384        }
385
386        let w = src.dims[0].size;
387        let h = src.dims[1].size;
388
389        if w == 0 || h == 0 {
390            return None;
391        }
392
393        let fft_row = self.planner.plan_fft_inverse(w);
394        let fft_col = self.planner.plan_fft_inverse(h);
395        let scale = 1.0 / (w * h) as f64;
396
397        let mut data = vec![Complex::new(0.0, 0.0); w * h];
398        for i in 0..w * h {
399            data[i] = Complex::new(
400                src.data.get_as_f64(i).unwrap_or(0.0),
401                0.0,
402            );
403        }
404
405        if suppress_dc {
406            data[0] = Complex::new(0.0, 0.0);
407        }
408
409        let mut col_buf = vec![Complex::new(0.0, 0.0); h];
410        for col in 0..w {
411            for row in 0..h {
412                col_buf[row] = data[row * w + col];
413            }
414            fft_col.process(&mut col_buf);
415            for row in 0..h {
416                data[row * w + col] = col_buf[row];
417            }
418        }
419
420        let mut row_buf = vec![Complex::new(0.0, 0.0); w];
421        for row in 0..h {
422            row_buf.copy_from_slice(&data[row * w..(row * w + w)]);
423            fft_row.process(&mut row_buf);
424            data[row * w..(row * w + w)].copy_from_slice(&row_buf);
425        }
426
427        let magnitudes: Vec<f64> = data.iter().map(|c| c.norm() * scale).collect();
428
429        let dims = vec![NDDimension::new(w), NDDimension::new(h)];
430        let mut arr = NDArray::new(dims, NDDataType::Float64);
431        arr.data = NDDataBuffer::F64(magnitudes);
432        arr.unique_id = src.unique_id;
433        arr.timestamp = src.timestamp;
434        arr.attributes = src.attributes.clone();
435        Some(arr)
436    }
437
438    /// Apply magnitude averaging. Returns the averaged magnitudes if ready.
439    fn apply_averaging(&mut self, magnitudes: &[f64]) -> Vec<f64> {
440        let num_avg = self.config.num_average;
441        if num_avg <= 1 {
442            return magnitudes.to_vec();
443        }
444
445        let buf = self.avg_buffer.get_or_insert_with(|| vec![0.0; magnitudes.len()]);
446
447        // Reset if buffer size changed (shouldn't happen after check_dims_changed, but guard)
448        if buf.len() != magnitudes.len() {
449            *buf = vec![0.0; magnitudes.len()];
450            self.avg_count = 0;
451        }
452
453        // Accumulate
454        for (b, &m) in buf.iter_mut().zip(magnitudes.iter()) {
455            *b += m;
456        }
457        self.avg_count += 1;
458
459        if self.avg_count >= num_avg {
460            // Output averaged result and reset
461            let result: Vec<f64> = buf.iter().map(|&v| v / self.avg_count as f64).collect();
462            buf.iter_mut().for_each(|v| *v = 0.0);
463            self.avg_count = 0;
464            result
465        } else {
466            // Not enough frames yet — return current partial average
467            buf.iter().map(|&v| v / self.avg_count as f64).collect()
468        }
469    }
470}
471
472impl NDPluginProcess for FFTProcessor {
473    fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
474        self.check_dims_changed(&array.dims);
475
476        let result = self.compute_fft(array);
477        match result {
478            Some(mut out) => {
479                if self.config.num_average > 1 {
480                    if let NDDataBuffer::F64(ref mags) = out.data {
481                        let averaged = self.apply_averaging(mags);
482                        out.data = NDDataBuffer::F64(averaged);
483                    }
484                }
485                ProcessResult::arrays(vec![Arc::new(out)])
486            }
487            None => ProcessResult::empty(),
488        }
489    }
490
491    fn plugin_type(&self) -> &str {
492        "NDPluginFFT"
493    }
494}
495
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows the Float64 payload of `arr`, failing the test on any other type.
    fn f64_data(arr: &NDArray) -> &[f64] {
        match arr.data {
            NDDataBuffer::F64(ref v) => v.as_slice(),
            _ => panic!("expected F64 data"),
        }
    }

    /// Mutable counterpart of `f64_data`, used to fill test inputs.
    fn f64_data_mut(arr: &mut NDArray) -> &mut [f64] {
        match arr.data {
            NDDataBuffer::F64(ref mut v) => v.as_mut_slice(),
            _ => panic!("expected F64 data"),
        }
    }

    /// Creates a Float64 array with `dims` and every element set to `value`.
    fn constant(dims: Vec<NDDimension>, value: f64) -> NDArray {
        let mut arr = NDArray::new(dims, NDDataType::Float64);
        for x in f64_data_mut(&mut arr).iter_mut() {
            *x = value;
        }
        arr
    }

    #[test]
    fn test_fft_1d_dc() {
        // Constant signal: DC component should dominate
        let arr = constant(vec![NDDimension::new(8)], 1.0);

        let result = fft_1d_rows(&arr, false).unwrap();
        let v = f64_data(&result);
        // DC component (k=0) should be 8.0
        assert!((v[0] - 8.0).abs() < 1e-10);
        // Other components should be ~0
        assert!(v[1].abs() < 1e-10);
    }

    #[test]
    fn test_fft_1d_sine() {
        // Sine wave at frequency 1: peak at k=1 and k=N-1
        let n = 16;
        let mut arr = NDArray::new(vec![NDDimension::new(n)], NDDataType::Float64);
        for (i, x) in f64_data_mut(&mut arr).iter_mut().enumerate() {
            *x = (2.0 * std::f64::consts::PI * i as f64 / n as f64).sin();
        }

        let result = fft_1d_rows(&arr, false).unwrap();
        let v = f64_data(&result);
        // DC should be ~0
        assert!(v[0].abs() < 1e-10);
        // Peak at k=1
        assert!(v[1] > 7.0);
        // k=2 should be small
        assert!(v[2].abs() < 1e-10);
    }

    #[test]
    fn test_fft_2d_dimensions() {
        let arr = NDArray::new(
            vec![NDDimension::new(4), NDDimension::new(4)],
            NDDataType::UInt8,
        );
        let result = fft_2d(&arr, false).unwrap();
        assert_eq!(result.dims[0].size, 4);
        assert_eq!(result.dims[1].size, 4);
        assert_eq!(result.data.data_type(), NDDataType::Float64);
    }

    #[test]
    fn test_fft_1d_suppress_dc() {
        // Constant signal: DC component should be suppressed
        let arr = constant(vec![NDDimension::new(8)], 1.0);

        let result = fft_1d_rows(&arr, true).unwrap();
        let v = f64_data(&result);
        // DC component should be zeroed out
        assert!((v[0]).abs() < 1e-15);
        // Other components should still be ~0 for constant signal
        assert!(v[1].abs() < 1e-10);
    }

    #[test]
    fn test_fft_2d_suppress_dc() {
        // 4x4 constant array, suppress_dc should zero out [0,0]
        let arr = constant(vec![NDDimension::new(4), NDDimension::new(4)], 3.0);

        let result = fft_2d(&arr, true).unwrap();
        let v = f64_data(&result);
        // DC at [0,0] should be zeroed
        assert!((v[0]).abs() < 1e-15);
    }

    #[test]
    fn test_fft_2d_known_dc() {
        // 4x4 constant=2.0 => DC should be 4*4*2 = 32
        let arr = constant(vec![NDDimension::new(4), NDDimension::new(4)], 2.0);

        let result = fft_2d(&arr, false).unwrap();
        let v = f64_data(&result);
        assert!((v[0] - 32.0).abs() < 1e-10, "DC = {}, expected 32", v[0]);
        // All other bins should be ~0
        for (i, bin) in v.iter().enumerate().skip(1) {
            assert!(bin.abs() < 1e-10, "bin {} = {}, expected ~0", i, bin);
        }
    }

    #[test]
    fn test_fft_1d_known_cosine_peaks() {
        // Cosine at frequency 3 in N=16: peaks at k=3 and k=N-3=13
        let n = 16;
        let mut arr = NDArray::new(vec![NDDimension::new(n)], NDDataType::Float64);
        for (i, x) in f64_data_mut(&mut arr).iter_mut().enumerate() {
            *x = (2.0 * std::f64::consts::PI * 3.0 * i as f64 / n as f64).cos();
        }

        let result = fft_1d_rows(&arr, false).unwrap();
        let v = f64_data(&result);
        // DC should be ~0
        assert!(v[0].abs() < 1e-10);
        // k=3 and k=13 should have magnitude N/2 = 8
        assert!(
            (v[3] - 8.0).abs() < 1e-10,
            "k=3 magnitude = {}, expected 8",
            v[3]
        );
        assert!(
            (v[13] - 8.0).abs() < 1e-10,
            "k=13 magnitude = {}, expected 8",
            v[13]
        );
        // All remaining non-DC bins should be ~0
        for k in [1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15] {
            assert!(v[k].abs() < 1e-10, "k={} magnitude = {}, expected ~0", k, v[k]);
        }
    }

    #[test]
    fn test_processor_with_config() {
        let config = FFTConfig {
            mode: FFTMode::Rows1D,
            direction: FFTDirection::Forward,
            suppress_dc: true,
            num_average: 0,
        };
        let mut proc = FFTProcessor::with_config(config);
        let pool = NDArrayPool::new(0);

        let arr = constant(vec![NDDimension::new(8)], 5.0);

        let result = proc.process_array(&arr, &pool);
        assert_eq!(result.output_arrays.len(), 1);
        let v = f64_data(&result.output_arrays[0]);
        // suppress_dc: DC should be 0
        assert!(v[0].abs() < 1e-15);
    }

    #[test]
    fn test_processor_averaging() {
        let config = FFTConfig {
            mode: FFTMode::Rows1D,
            direction: FFTDirection::Forward,
            suppress_dc: false,
            num_average: 2,
        };
        let mut proc = FFTProcessor::with_config(config);
        let pool = NDArrayPool::new(0);

        // Frame 1: constant = 2.0 => DC magnitude = 8*2 = 16
        let arr1 = constant(vec![NDDimension::new(8)], 2.0);
        // Frame 2: constant = 4.0 => DC magnitude = 8*4 = 32
        let arr2 = constant(vec![NDDimension::new(8)], 4.0);

        let r1 = proc.process_array(&arr1, &pool);
        assert_eq!(r1.output_arrays.len(), 1);
        // After 1 frame with num_average=2, partial average = 16/1 = 16
        let v1 = f64_data(&r1.output_arrays[0]);
        assert!((v1[0] - 16.0).abs() < 1e-10, "partial avg DC = {}", v1[0]);

        let r2 = proc.process_array(&arr2, &pool);
        assert_eq!(r2.output_arrays.len(), 1);
        // After 2 frames: average = (16+32)/2 = 24, then reset
        let v2 = f64_data(&r2.output_arrays[0]);
        assert!((v2[0] - 24.0).abs() < 1e-10, "averaged DC = {}", v2[0]);
    }

    #[test]
    fn test_processor_averaging_dimension_change_resets() {
        let config = FFTConfig {
            mode: FFTMode::Rows1D,
            direction: FFTDirection::Forward,
            suppress_dc: false,
            num_average: 3,
        };
        let mut proc = FFTProcessor::with_config(config);
        let pool = NDArrayPool::new(0);

        // Frame 1: width=8
        let arr1 = constant(vec![NDDimension::new(8)], 1.0);
        let _ = proc.process_array(&arr1, &pool);
        assert_eq!(proc.avg_count, 1);

        // Frame 2: width=4 — dimension change should reset
        let arr2 = constant(vec![NDDimension::new(4)], 1.0);
        let _ = proc.process_array(&arr2, &pool);
        // After dimension change, avg_count should be 1 (reset + one new frame)
        assert_eq!(proc.avg_count, 1);
    }

    #[test]
    fn test_fft_1d_multirow() {
        // 2 rows, each a different constant
        let w = 4;
        let h = 2;
        let mut arr = NDArray::new(
            vec![NDDimension::new(w), NDDimension::new(h)],
            NDDataType::Float64,
        );
        {
            let data = f64_data_mut(&mut arr);
            // Row 0: all 1.0
            data[..w].iter_mut().for_each(|x| *x = 1.0);
            // Row 1: all 3.0
            data[w..2 * w].iter_mut().for_each(|x| *x = 3.0);
        }

        let result = fft_1d_rows(&arr, false).unwrap();
        let v = f64_data(&result);
        // Row 0 DC = 4*1 = 4
        assert!((v[0] - 4.0).abs() < 1e-10);
        // Row 1 DC = 4*3 = 12
        assert!((v[w] - 12.0).abs() < 1e-10);
    }

    #[test]
    fn test_inverse_fft_1d() {
        // For a real constant signal, forward FFT gives [N, 0, 0, ...0].
        // IFFT of [N, 0, ...0] (real input) should give constant = 1.0 for each sample.
        let n = 8;
        let mut arr = NDArray::new(vec![NDDimension::new(n)], NDDataType::Float64);
        f64_data_mut(&mut arr)[0] = 8.0; // DC = N; rest are 0

        let config = FFTConfig {
            mode: FFTMode::Rows1D,
            direction: FFTDirection::Inverse,
            suppress_dc: false,
            num_average: 0,
        };
        let mut proc = FFTProcessor::with_config(config);
        let pool = NDArrayPool::new(0);

        let result = proc.process_array(&arr, &pool);
        assert_eq!(result.output_arrays.len(), 1);
        let v = f64_data(&result.output_arrays[0]);
        // Each sample should be magnitude 1.0 (8/8 = 1.0 after normalization)
        for (i, sample) in v.iter().take(n).enumerate() {
            assert!(
                (sample - 1.0).abs() < 1e-10,
                "sample {} = {}, expected 1.0",
                i,
                sample
            );
        }
    }

    #[test]
    fn test_fft_preserves_metadata() {
        let mut arr = NDArray::new(vec![NDDimension::new(4)], NDDataType::Float64);
        arr.unique_id = 42;
        f64_data_mut(&mut arr)[0] = 1.0;

        let result = fft_1d_rows(&arr, false).unwrap();
        assert_eq!(result.unique_id, 42);
        assert_eq!(result.timestamp, arr.timestamp);
    }
}