Skip to main content

ad_plugins_rs/
fft.rs

1use std::sync::Arc;
2
3use ad_core_rs::ndarray::{NDArray, NDDataBuffer, NDDataType, NDDimension};
4use ad_core_rs::ndarray_pool::NDArrayPool;
5use ad_core_rs::plugin::runtime::{NDPluginProcess, ProcessResult};
6use rustfft::FftPlanner;
7use rustfft::num_complex::Complex;
8
/// Selects which axes of the input the FFT is taken along.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FFTMode {
    /// Independent 1-D transform of every row (`dims[0]` samples per row).
    Rows1D,
    /// Separable 2-D transform of the `dims[0]` x `dims[1]` plane.
    Full2D,
}
15
/// Direction of the transform applied by the FFT plugin.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FFTDirection {
    /// Forward transform of real-valued input.
    Forward,
    /// Inverse transform of real-valued input, normalised by 1/N.
    Inverse,
}
22
23/// Configuration for FFT processing.
24pub struct FFTConfig {
25    pub mode: FFTMode,
26    pub direction: FFTDirection,
27    /// Zero out DC component (k=0) in the output magnitudes.
28    pub suppress_dc: bool,
29    /// Average N frames of magnitude. 0 or 1 means no averaging.
30    pub num_average: usize,
31}
32
33impl Default for FFTConfig {
34    fn default() -> Self {
35        Self {
36            mode: FFTMode::Rows1D,
37            direction: FFTDirection::Forward,
38            suppress_dc: false,
39            num_average: 0,
40        }
41    }
42}
43
44/// Compute 1D FFT magnitude for each row of a 2D array using rustfft.
45/// Returns a Float64 array with the same dimensions.
46pub fn fft_1d_rows(src: &NDArray, suppress_dc: bool) -> Option<NDArray> {
47    if src.dims.is_empty() {
48        return None;
49    }
50
51    let width = src.dims[0].size;
52    let height = if src.dims.len() >= 2 {
53        src.dims[1].size
54    } else {
55        1
56    };
57
58    if width == 0 {
59        return None;
60    }
61
62    let mut planner = FftPlanner::<f64>::new();
63    let fft = planner.plan_fft_forward(width);
64
65    let mut magnitudes = vec![0.0f64; width * height];
66    let mut row_buf = vec![Complex::new(0.0, 0.0); width];
67
68    for row in 0..height {
69        // Fill complex buffer: real = pixel value, imag = 0
70        for i in 0..width {
71            row_buf[i] = Complex::new(src.data.get_as_f64(row * width + i).unwrap_or(0.0), 0.0);
72        }
73
74        fft.process(&mut row_buf);
75
76        // Compute magnitudes
77        for (i, c) in row_buf.iter().enumerate() {
78            magnitudes[row * width + i] = c.norm();
79        }
80
81        if suppress_dc {
82            magnitudes[row * width] = 0.0;
83        }
84    }
85
86    let dims = src.dims.clone();
87    let mut arr = NDArray::new(dims, NDDataType::Float64);
88    arr.data = NDDataBuffer::F64(magnitudes);
89    arr.unique_id = src.unique_id;
90    arr.timestamp = src.timestamp;
91    arr.attributes = src.attributes.clone();
92    Some(arr)
93}
94
95/// Compute 2D FFT magnitude using separable row-then-column FFT via rustfft.
96pub fn fft_2d(src: &NDArray, suppress_dc: bool) -> Option<NDArray> {
97    if src.dims.len() < 2 {
98        return None;
99    }
100
101    let w = src.dims[0].size;
102    let h = src.dims[1].size;
103
104    if w == 0 || h == 0 {
105        return None;
106    }
107
108    let mut planner = FftPlanner::<f64>::new();
109    let fft_row = planner.plan_fft_forward(w);
110    let fft_col = planner.plan_fft_forward(h);
111
112    // Step 1: Row FFTs — build a w×h complex buffer
113    let mut data = vec![Complex::new(0.0, 0.0); w * h];
114    let mut row_buf = vec![Complex::new(0.0, 0.0); w];
115
116    for row in 0..h {
117        for i in 0..w {
118            row_buf[i] = Complex::new(src.data.get_as_f64(row * w + i).unwrap_or(0.0), 0.0);
119        }
120        fft_row.process(&mut row_buf);
121        data[row * w..(row * w + w)].copy_from_slice(&row_buf);
122    }
123
124    // Step 2: Column FFTs
125    let mut col_buf = vec![Complex::new(0.0, 0.0); h];
126
127    for col in 0..w {
128        // Extract column
129        for row in 0..h {
130            col_buf[row] = data[row * w + col];
131        }
132        fft_col.process(&mut col_buf);
133        // Write back
134        for row in 0..h {
135            data[row * w + col] = col_buf[row];
136        }
137    }
138
139    // Step 3: Compute magnitudes
140    let mut magnitudes: Vec<f64> = data.iter().map(|c| c.norm()).collect();
141
142    if suppress_dc {
143        magnitudes[0] = 0.0;
144    }
145
146    let dims = vec![NDDimension::new(w), NDDimension::new(h)];
147    let mut arr = NDArray::new(dims, NDDataType::Float64);
148    arr.data = NDDataBuffer::F64(magnitudes);
149    arr.unique_id = src.unique_id;
150    arr.timestamp = src.timestamp;
151    arr.attributes = src.attributes.clone();
152    Some(arr)
153}
154
155/// FFT processing engine with cached planner and optional magnitude averaging.
156#[derive(Default)]
157struct FFTParamIndices {
158    direction: Option<usize>,
159    suppress_dc: Option<usize>,
160    num_average: Option<usize>,
161    num_averaged: Option<usize>,
162    reset_average: Option<usize>,
163}
164
165pub struct FFTProcessor {
166    config: FFTConfig,
167    planner: FftPlanner<f64>,
168    /// Running average magnitude buffer.
169    avg_buffer: Option<Vec<f64>>,
170    /// Number of frames accumulated so far.
171    avg_count: usize,
172    /// Cached dimensions to detect changes.
173    cached_dims: Vec<usize>,
174    params: FFTParamIndices,
175}
176
177impl FFTProcessor {
178    pub fn new(mode: FFTMode) -> Self {
179        Self {
180            config: FFTConfig {
181                mode,
182                direction: FFTDirection::Forward,
183                suppress_dc: false,
184                num_average: 0,
185            },
186            planner: FftPlanner::new(),
187            avg_buffer: None,
188            avg_count: 0,
189            cached_dims: Vec::new(),
190            params: FFTParamIndices::default(),
191        }
192    }
193
194    pub fn with_config(config: FFTConfig) -> Self {
195        Self {
196            config,
197            planner: FftPlanner::new(),
198            avg_buffer: None,
199            avg_count: 0,
200            cached_dims: Vec::new(),
201            params: FFTParamIndices::default(),
202        }
203    }
204
205    /// Check if dimensions changed and reset averaging state if so.
206    fn check_dims_changed(&mut self, dims: &[NDDimension]) {
207        let current: Vec<usize> = dims.iter().map(|d| d.size).collect();
208        if current != self.cached_dims {
209            self.cached_dims = current;
210            self.avg_buffer = None;
211            self.avg_count = 0;
212        }
213    }
214
215    /// Compute FFT using cached planner for plan reuse across frames.
216    fn compute_fft(&mut self, src: &NDArray) -> Option<NDArray> {
217        let suppress_dc = self.config.suppress_dc;
218
219        match (self.config.mode, self.config.direction) {
220            (FFTMode::Rows1D, FFTDirection::Forward) => {
221                self.compute_fft_1d_rows_forward(src, suppress_dc)
222            }
223            (FFTMode::Rows1D, FFTDirection::Inverse) => {
224                self.compute_fft_1d_rows_inverse(src, suppress_dc)
225            }
226            (FFTMode::Full2D, FFTDirection::Forward) => {
227                self.compute_fft_2d_forward(src, suppress_dc)
228            }
229            (FFTMode::Full2D, FFTDirection::Inverse) => {
230                self.compute_fft_2d_inverse(src, suppress_dc)
231            }
232        }
233    }
234
235    fn compute_fft_1d_rows_forward(&mut self, src: &NDArray, suppress_dc: bool) -> Option<NDArray> {
236        if src.dims.is_empty() {
237            return None;
238        }
239
240        let width = src.dims[0].size;
241        let height = if src.dims.len() >= 2 {
242            src.dims[1].size
243        } else {
244            1
245        };
246
247        if width == 0 {
248            return None;
249        }
250
251        let fft = self.planner.plan_fft_forward(width);
252
253        let mut magnitudes = vec![0.0f64; width * height];
254        let mut row_buf = vec![Complex::new(0.0, 0.0); width];
255
256        for row in 0..height {
257            for i in 0..width {
258                row_buf[i] = Complex::new(src.data.get_as_f64(row * width + i).unwrap_or(0.0), 0.0);
259            }
260            fft.process(&mut row_buf);
261            for (i, c) in row_buf.iter().enumerate() {
262                magnitudes[row * width + i] = c.norm();
263            }
264            if suppress_dc {
265                magnitudes[row * width] = 0.0;
266            }
267        }
268
269        let dims = src.dims.clone();
270        let mut arr = NDArray::new(dims, NDDataType::Float64);
271        arr.data = NDDataBuffer::F64(magnitudes);
272        arr.unique_id = src.unique_id;
273        arr.timestamp = src.timestamp;
274        arr.attributes = src.attributes.clone();
275        Some(arr)
276    }
277
278    fn compute_fft_1d_rows_inverse(&mut self, src: &NDArray, suppress_dc: bool) -> Option<NDArray> {
279        if src.dims.is_empty() {
280            return None;
281        }
282
283        let width = src.dims[0].size;
284        let height = if src.dims.len() >= 2 {
285            src.dims[1].size
286        } else {
287            1
288        };
289
290        if width == 0 {
291            return None;
292        }
293
294        let fft = self.planner.plan_fft_inverse(width);
295        let scale = 1.0 / width as f64;
296
297        let mut magnitudes = vec![0.0f64; width * height];
298        let mut row_buf = vec![Complex::new(0.0, 0.0); width];
299
300        for row in 0..height {
301            for i in 0..width {
302                row_buf[i] = Complex::new(src.data.get_as_f64(row * width + i).unwrap_or(0.0), 0.0);
303            }
304            if suppress_dc {
305                row_buf[0] = Complex::new(0.0, 0.0);
306            }
307            fft.process(&mut row_buf);
308            for (i, c) in row_buf.iter().enumerate() {
309                magnitudes[row * width + i] = c.norm() * scale;
310            }
311        }
312
313        let dims = src.dims.clone();
314        let mut arr = NDArray::new(dims, NDDataType::Float64);
315        arr.data = NDDataBuffer::F64(magnitudes);
316        arr.unique_id = src.unique_id;
317        arr.timestamp = src.timestamp;
318        arr.attributes = src.attributes.clone();
319        Some(arr)
320    }
321
322    fn compute_fft_2d_forward(&mut self, src: &NDArray, suppress_dc: bool) -> Option<NDArray> {
323        if src.dims.len() < 2 {
324            return None;
325        }
326
327        let w = src.dims[0].size;
328        let h = src.dims[1].size;
329
330        if w == 0 || h == 0 {
331            return None;
332        }
333
334        let fft_row = self.planner.plan_fft_forward(w);
335        let fft_col = self.planner.plan_fft_forward(h);
336
337        let mut data = vec![Complex::new(0.0, 0.0); w * h];
338        let mut row_buf = vec![Complex::new(0.0, 0.0); w];
339
340        for row in 0..h {
341            for i in 0..w {
342                row_buf[i] = Complex::new(src.data.get_as_f64(row * w + i).unwrap_or(0.0), 0.0);
343            }
344            fft_row.process(&mut row_buf);
345            data[row * w..(row * w + w)].copy_from_slice(&row_buf);
346        }
347
348        let mut col_buf = vec![Complex::new(0.0, 0.0); h];
349        for col in 0..w {
350            for row in 0..h {
351                col_buf[row] = data[row * w + col];
352            }
353            fft_col.process(&mut col_buf);
354            for row in 0..h {
355                data[row * w + col] = col_buf[row];
356            }
357        }
358
359        let mut magnitudes: Vec<f64> = data.iter().map(|c| c.norm()).collect();
360
361        if suppress_dc {
362            magnitudes[0] = 0.0;
363        }
364
365        let dims = vec![NDDimension::new(w), NDDimension::new(h)];
366        let mut arr = NDArray::new(dims, NDDataType::Float64);
367        arr.data = NDDataBuffer::F64(magnitudes);
368        arr.unique_id = src.unique_id;
369        arr.timestamp = src.timestamp;
370        arr.attributes = src.attributes.clone();
371        Some(arr)
372    }
373
374    fn compute_fft_2d_inverse(&mut self, src: &NDArray, suppress_dc: bool) -> Option<NDArray> {
375        if src.dims.len() < 2 {
376            return None;
377        }
378
379        let w = src.dims[0].size;
380        let h = src.dims[1].size;
381
382        if w == 0 || h == 0 {
383            return None;
384        }
385
386        let fft_row = self.planner.plan_fft_inverse(w);
387        let fft_col = self.planner.plan_fft_inverse(h);
388        let scale = 1.0 / (w * h) as f64;
389
390        let mut data = vec![Complex::new(0.0, 0.0); w * h];
391        for i in 0..w * h {
392            data[i] = Complex::new(src.data.get_as_f64(i).unwrap_or(0.0), 0.0);
393        }
394
395        if suppress_dc {
396            data[0] = Complex::new(0.0, 0.0);
397        }
398
399        let mut col_buf = vec![Complex::new(0.0, 0.0); h];
400        for col in 0..w {
401            for row in 0..h {
402                col_buf[row] = data[row * w + col];
403            }
404            fft_col.process(&mut col_buf);
405            for row in 0..h {
406                data[row * w + col] = col_buf[row];
407            }
408        }
409
410        let mut row_buf = vec![Complex::new(0.0, 0.0); w];
411        for row in 0..h {
412            row_buf.copy_from_slice(&data[row * w..(row * w + w)]);
413            fft_row.process(&mut row_buf);
414            data[row * w..(row * w + w)].copy_from_slice(&row_buf);
415        }
416
417        let magnitudes: Vec<f64> = data.iter().map(|c| c.norm() * scale).collect();
418
419        let dims = vec![NDDimension::new(w), NDDimension::new(h)];
420        let mut arr = NDArray::new(dims, NDDataType::Float64);
421        arr.data = NDDataBuffer::F64(magnitudes);
422        arr.unique_id = src.unique_id;
423        arr.timestamp = src.timestamp;
424        arr.attributes = src.attributes.clone();
425        Some(arr)
426    }
427
428    /// Apply magnitude averaging. Returns the averaged magnitudes if ready.
429    fn apply_averaging(&mut self, magnitudes: &[f64]) -> Vec<f64> {
430        let num_avg = self.config.num_average;
431        if num_avg <= 1 {
432            return magnitudes.to_vec();
433        }
434
435        let buf = self
436            .avg_buffer
437            .get_or_insert_with(|| vec![0.0; magnitudes.len()]);
438
439        // Reset if buffer size changed (shouldn't happen after check_dims_changed, but guard)
440        if buf.len() != magnitudes.len() {
441            *buf = vec![0.0; magnitudes.len()];
442            self.avg_count = 0;
443        }
444
445        // Accumulate
446        for (b, &m) in buf.iter_mut().zip(magnitudes.iter()) {
447            *b += m;
448        }
449        self.avg_count += 1;
450
451        if self.avg_count >= num_avg {
452            // Output averaged result and reset
453            let result: Vec<f64> = buf.iter().map(|&v| v / self.avg_count as f64).collect();
454            buf.iter_mut().for_each(|v| *v = 0.0);
455            self.avg_count = 0;
456            result
457        } else {
458            // Not enough frames yet — return current partial average
459            buf.iter().map(|&v| v / self.avg_count as f64).collect()
460        }
461    }
462}
463
464impl NDPluginProcess for FFTProcessor {
465    fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
466        use ad_core_rs::plugin::runtime::ParamUpdate;
467
468        self.check_dims_changed(&array.dims);
469
470        let result = self.compute_fft(array);
471        let mut updates = Vec::new();
472        if let Some(idx) = self.params.num_averaged {
473            updates.push(ParamUpdate::int32(idx, self.avg_count as i32));
474        }
475
476        match result {
477            Some(mut out) => {
478                if self.config.num_average > 1 {
479                    if let NDDataBuffer::F64(ref mags) = out.data {
480                        let averaged = self.apply_averaging(mags);
481                        out.data = NDDataBuffer::F64(averaged);
482                    }
483                }
484                let mut r = ProcessResult::arrays(vec![Arc::new(out)]);
485                r.param_updates = updates;
486                r
487            }
488            None => ProcessResult::sink(updates),
489        }
490    }
491
492    fn plugin_type(&self) -> &str {
493        "NDPluginFFT"
494    }
495
496    fn register_params(
497        &mut self,
498        base: &mut asyn_rs::port::PortDriverBase,
499    ) -> asyn_rs::error::AsynResult<()> {
500        use asyn_rs::param::ParamType;
501        base.create_param("FFT_TIME_PER_POINT", ParamType::Float64)?;
502        base.create_param("FFT_TIME_AXIS", ParamType::Float64Array)?;
503        base.create_param("FFT_FREQ_AXIS", ParamType::Float64Array)?;
504        base.create_param("FFT_DIRECTION", ParamType::Int32)?;
505        base.create_param("FFT_SUPPRESS_DC", ParamType::Int32)?;
506        base.create_param("FFT_NUM_AVERAGE", ParamType::Int32)?;
507        base.create_param("FFT_NUM_AVERAGED", ParamType::Int32)?;
508        base.create_param("FFT_RESET_AVERAGE", ParamType::Int32)?;
509        base.create_param("FFT_TIME_SERIES", ParamType::Float64Array)?;
510        base.create_param("FFT_REAL", ParamType::Float64Array)?;
511        base.create_param("FFT_IMAGINARY", ParamType::Float64Array)?;
512        base.create_param("FFT_ABS_VALUE", ParamType::Float64Array)?;
513
514        self.params.direction = base.find_param("FFT_DIRECTION");
515        self.params.suppress_dc = base.find_param("FFT_SUPPRESS_DC");
516        self.params.num_average = base.find_param("FFT_NUM_AVERAGE");
517        self.params.num_averaged = base.find_param("FFT_NUM_AVERAGED");
518        self.params.reset_average = base.find_param("FFT_RESET_AVERAGE");
519        Ok(())
520    }
521
522    fn on_param_change(
523        &mut self,
524        reason: usize,
525        params: &ad_core_rs::plugin::runtime::PluginParamSnapshot,
526    ) -> ad_core_rs::plugin::runtime::ParamChangeResult {
527        if Some(reason) == self.params.direction {
528            self.config.direction = if params.value.as_i32() == 0 {
529                FFTDirection::Forward
530            } else {
531                FFTDirection::Inverse
532            };
533        } else if Some(reason) == self.params.suppress_dc {
534            self.config.suppress_dc = params.value.as_i32() != 0;
535        } else if Some(reason) == self.params.num_average {
536            self.config.num_average = params.value.as_i32().max(0) as usize;
537        } else if Some(reason) == self.params.reset_average {
538            if params.value.as_i32() != 0 {
539                self.avg_buffer = None;
540                self.avg_count = 0;
541            }
542        }
543        ad_core_rs::plugin::runtime::ParamChangeResult::updates(vec![])
544    }
545}
546
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Builds a Float64 array with the given dimension sizes whose element `i` is `value(i)`.
    fn f64_array(sizes: &[usize], value: impl Fn(usize) -> f64) -> NDArray {
        let dims: Vec<NDDimension> = sizes.iter().map(|&s| NDDimension::new(s)).collect();
        let mut arr = NDArray::new(dims, NDDataType::Float64);
        match arr.data {
            NDDataBuffer::F64(ref mut v) => {
                for (i, x) in v.iter_mut().enumerate() {
                    *x = value(i);
                }
            }
            _ => panic!("NDArray::new(.., Float64) did not allocate an F64 buffer"),
        }
        arr
    }

    /// Returns the F64 payload of `arr`; fails the test (instead of silently skipping
    /// assertions) for any other buffer type.
    fn f64_values(arr: &NDArray) -> &[f64] {
        match &arr.data {
            NDDataBuffer::F64(v) => v.as_slice(),
            other => panic!("expected F64 data, got {:?}", other.data_type()),
        }
    }

    #[test]
    fn test_fft_1d_dc() {
        // Constant signal: DC component should dominate
        let arr = f64_array(&[8], |_| 1.0);
        let result = fft_1d_rows(&arr, false).unwrap();
        let v = f64_values(&result);
        // DC component (k=0) should be 8.0
        assert!((v[0] - 8.0).abs() < 1e-10);
        // Other components should be ~0
        assert!(v[1].abs() < 1e-10);
    }

    #[test]
    fn test_fft_1d_sine() {
        // Sine wave at frequency 1: peak at k=1 and k=N-1
        let n = 16;
        let arr = f64_array(&[n], |i| (2.0 * PI * i as f64 / n as f64).sin());
        let result = fft_1d_rows(&arr, false).unwrap();
        let v = f64_values(&result);
        assert!(v[0].abs() < 1e-10);
        assert!(v[1] > 7.0);
        assert!(v[2].abs() < 1e-10);
    }

    #[test]
    fn test_fft_2d_dimensions() {
        let arr = NDArray::new(
            vec![NDDimension::new(4), NDDimension::new(4)],
            NDDataType::UInt8,
        );
        let result = fft_2d(&arr, false).unwrap();
        assert_eq!(result.dims[0].size, 4);
        assert_eq!(result.dims[1].size, 4);
        assert_eq!(result.data.data_type(), NDDataType::Float64);
    }

    #[test]
    fn test_fft_1d_suppress_dc() {
        // Constant signal: DC component should be suppressed
        let arr = f64_array(&[8], |_| 1.0);
        let result = fft_1d_rows(&arr, true).unwrap();
        let v = f64_values(&result);
        assert!(v[0].abs() < 1e-15);
        assert!(v[1].abs() < 1e-10);
    }

    #[test]
    fn test_fft_2d_suppress_dc() {
        // 4x4 constant array, suppress_dc should zero out [0,0]
        let arr = f64_array(&[4, 4], |_| 3.0);
        let result = fft_2d(&arr, true).unwrap();
        assert!(f64_values(&result)[0].abs() < 1e-15);
    }

    #[test]
    fn test_fft_2d_known_dc() {
        // 4x4 constant=2.0 => DC should be 4*4*2 = 32
        let arr = f64_array(&[4, 4], |_| 2.0);
        let result = fft_2d(&arr, false).unwrap();
        let v = f64_values(&result);
        assert!((v[0] - 32.0).abs() < 1e-10, "DC = {}, expected 32", v[0]);
        for (i, m) in v.iter().enumerate().skip(1) {
            assert!(m.abs() < 1e-10, "bin {} = {}, expected ~0", i, m);
        }
    }

    #[test]
    fn test_fft_2d_non_square_layout() {
        // w=4, h=2, row-major. Row 1 all ones => energy only in column 0 of both output rows.
        let arr = f64_array(&[4, 2], |i| if i >= 4 { 1.0 } else { 0.0 });
        let result = fft_2d(&arr, false).unwrap();
        let v = f64_values(&result);
        let expected = [4.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0];
        for (i, (&got, &want)) in v.iter().zip(expected.iter()).enumerate() {
            assert!((got - want).abs() < 1e-10, "bin {} = {}, expected {}", i, got, want);
        }

        // Column 0 all ones => energy only in output row 0.
        let arr = f64_array(&[4, 2], |i| if i % 4 == 0 { 1.0 } else { 0.0 });
        let result = fft_2d(&arr, false).unwrap();
        let v = f64_values(&result);
        let expected = [2.0, 2.0, 2.0, 2.0, 0.0, 0.0, 0.0, 0.0];
        for (i, (&got, &want)) in v.iter().zip(expected.iter()).enumerate() {
            assert!((got - want).abs() < 1e-10, "bin {} = {}, expected {}", i, got, want);
        }
    }

    #[test]
    fn test_fft_1d_known_cosine_peaks() {
        // Cosine at frequency 3 in N=16: peaks at k=3 and k=N-3=13
        let n = 16;
        let arr = f64_array(&[n], |i| (2.0 * PI * 3.0 * i as f64 / n as f64).cos());
        let result = fft_1d_rows(&arr, false).unwrap();
        let v = f64_values(&result);
        assert!((v[3] - 8.0).abs() < 1e-10, "k=3 magnitude = {}, expected 8", v[3]);
        assert!((v[13] - 8.0).abs() < 1e-10, "k=13 magnitude = {}, expected 8", v[13]);
        for k in [0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15] {
            assert!(v[k].abs() < 1e-10, "k={} magnitude = {}, expected ~0", k, v[k]);
        }
    }

    #[test]
    fn test_processor_new_defaults() {
        let processor = FFTProcessor::new(FFTMode::Full2D);
        assert_eq!(processor.config.mode, FFTMode::Full2D);
        assert_eq!(processor.config.direction, FFTDirection::Forward);
        assert!(!processor.config.suppress_dc);
        assert_eq!(processor.config.num_average, 0);
        assert_eq!(processor.avg_count, 0);
    }

    #[test]
    fn test_processor_with_config() {
        let config = FFTConfig {
            mode: FFTMode::Rows1D,
            direction: FFTDirection::Forward,
            suppress_dc: true,
            num_average: 0,
        };
        let mut processor = FFTProcessor::with_config(config);
        let pool = NDArrayPool::new(0);

        let arr = f64_array(&[8], |_| 5.0);
        let result = processor.process_array(&arr, &pool);
        assert_eq!(result.output_arrays.len(), 1);
        // suppress_dc: DC should be 0
        assert!(f64_values(&result.output_arrays[0])[0].abs() < 1e-15);
    }

    #[test]
    fn test_processor_averaging() {
        let config = FFTConfig {
            mode: FFTMode::Rows1D,
            direction: FFTDirection::Forward,
            suppress_dc: false,
            num_average: 2,
        };
        let mut processor = FFTProcessor::with_config(config);
        let pool = NDArrayPool::new(0);

        // Frame 1: constant = 2.0 => DC magnitude = 8*2 = 16
        let arr1 = f64_array(&[8], |_| 2.0);
        // Frame 2: constant = 4.0 => DC magnitude = 8*4 = 32
        let arr2 = f64_array(&[8], |_| 4.0);

        let r1 = processor.process_array(&arr1, &pool);
        assert_eq!(r1.output_arrays.len(), 1);
        // After 1 frame with num_average=2, partial average = 16/1 = 16
        let dc = f64_values(&r1.output_arrays[0])[0];
        assert!((dc - 16.0).abs() < 1e-10, "partial avg DC = {}", dc);

        let r2 = processor.process_array(&arr2, &pool);
        assert_eq!(r2.output_arrays.len(), 1);
        // After 2 frames: average = (16+32)/2 = 24, then reset
        let dc = f64_values(&r2.output_arrays[0])[0];
        assert!((dc - 24.0).abs() < 1e-10, "averaged DC = {}", dc);

        // The completed window restarts accumulation from scratch.
        let r3 = processor.process_array(&arr1, &pool);
        let dc = f64_values(&r3.output_arrays[0])[0];
        assert!((dc - 16.0).abs() < 1e-10, "restarted avg DC = {}", dc);
    }

    #[test]
    fn test_processor_averaging_dimension_change_resets() {
        let config = FFTConfig {
            mode: FFTMode::Rows1D,
            direction: FFTDirection::Forward,
            suppress_dc: false,
            num_average: 3,
        };
        let mut processor = FFTProcessor::with_config(config);
        let pool = NDArrayPool::new(0);

        // Frame 1: width=8
        let _ = processor.process_array(&f64_array(&[8], |_| 1.0), &pool);
        assert_eq!(processor.avg_count, 1);

        // Frame 2: width=4 — dimension change should reset
        let _ = processor.process_array(&f64_array(&[4], |_| 1.0), &pool);
        // After dimension change, avg_count should be 1 (reset + one new frame)
        assert_eq!(processor.avg_count, 1);
    }

    #[test]
    fn test_fft_1d_multirow() {
        // 2 rows: row 0 all 1.0, row 1 all 3.0
        let w = 4;
        let h = 2;
        let arr = f64_array(&[w, h], |i| if i < w { 1.0 } else { 3.0 });
        let result = fft_1d_rows(&arr, false).unwrap();
        let v = f64_values(&result);
        // Row 0 DC = 4*1 = 4
        assert!((v[0] - 4.0).abs() < 1e-10);
        // Row 1 DC = 4*3 = 12
        assert!((v[w] - 12.0).abs() < 1e-10);
    }

    #[test]
    fn test_inverse_fft_1d() {
        // IFFT of [N, 0, ..., 0] (real input) should give constant 1.0 for each sample.
        let n = 8;
        let arr = f64_array(&[n], |i| if i == 0 { 8.0 } else { 0.0 });

        let config = FFTConfig {
            mode: FFTMode::Rows1D,
            direction: FFTDirection::Inverse,
            suppress_dc: false,
            num_average: 0,
        };
        let mut processor = FFTProcessor::with_config(config);
        let pool = NDArrayPool::new(0);

        let result = processor.process_array(&arr, &pool);
        assert_eq!(result.output_arrays.len(), 1);
        let v = f64_values(&result.output_arrays[0]);
        for (i, &s) in v.iter().enumerate().take(n) {
            assert!((s - 1.0).abs() < 1e-10, "sample {} = {}, expected 1.0", i, s);
        }
    }

    #[test]
    fn test_inverse_fft_2d_dc_spike() {
        // IFFT of a w*h spike at (0, 0) on a non-square 4x2 plane is constant 1.0.
        let arr = f64_array(&[4, 2], |i| if i == 0 { 8.0 } else { 0.0 });
        let config = FFTConfig {
            mode: FFTMode::Full2D,
            direction: FFTDirection::Inverse,
            suppress_dc: false,
            num_average: 0,
        };
        let mut processor = FFTProcessor::with_config(config);
        let pool = NDArrayPool::new(0);

        let result = processor.process_array(&arr, &pool);
        assert_eq!(result.output_arrays.len(), 1);
        let v = f64_values(&result.output_arrays[0]);
        assert_eq!(v.len(), 8);
        for (i, &s) in v.iter().enumerate() {
            assert!((s - 1.0).abs() < 1e-10, "sample {} = {}, expected 1.0", i, s);
        }
    }

    #[test]
    fn test_fft_preserves_metadata() {
        let mut arr = f64_array(&[4], |i| if i == 0 { 1.0 } else { 0.0 });
        arr.unique_id = 42;

        let result = fft_1d_rows(&arr, false).unwrap();
        assert_eq!(result.unique_id, 42);
        assert_eq!(result.timestamp, arr.timestamp);
    }
}