1use std::sync::Arc;
2
3use ad_core_rs::ndarray::{NDArray, NDDataBuffer, NDDataType, NDDimension};
4use ad_core_rs::ndarray_pool::NDArrayPool;
5use ad_core_rs::plugin::runtime::{NDPluginProcess, ProcessResult};
6use rustfft::FftPlanner;
7use rustfft::num_complex::Complex;
8
/// How the transform is applied to the input array.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FFTMode {
    /// Independent 1-D transform of every row; `dims[0]` is the row length.
    Rows1D,
    /// Full 2-D transform over the first two dimensions.
    Full2D,
}

/// Which direction of the transform is computed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FFTDirection {
    /// Forward DFT; the output holds the magnitude of each frequency bin.
    Forward,
    /// Inverse DFT; the output holds the magnitude of each sample, scaled by `1/N`.
    Inverse,
}

/// User-selectable FFT settings.
///
/// All fields are plain `Copy` values, so the configuration is cheap to copy,
/// compare and print.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FFTConfig {
    /// Row-wise 1-D or full 2-D transform.
    pub mode: FFTMode,
    /// Forward or inverse transform.
    pub direction: FFTDirection,
    /// When `true`, the zero-frequency component is zeroed: output bin 0 for a
    /// forward transform, input bin 0 before an inverse transform.
    pub suppress_dc: bool,
    /// Number of frames per averaging block; `0` or `1` disables averaging.
    pub num_average: usize,
}

impl Default for FFTConfig {
    /// Row-wise forward transform, DC kept, averaging disabled.
    fn default() -> Self {
        Self {
            mode: FFTMode::Rows1D,
            direction: FFTDirection::Forward,
            suppress_dc: false,
            num_average: 0,
        }
    }
}
43
44pub fn fft_1d_rows(src: &NDArray, suppress_dc: bool) -> Option<NDArray> {
47 if src.dims.is_empty() {
48 return None;
49 }
50
51 let width = src.dims[0].size;
52 let height = if src.dims.len() >= 2 {
53 src.dims[1].size
54 } else {
55 1
56 };
57
58 if width == 0 {
59 return None;
60 }
61
62 let mut planner = FftPlanner::<f64>::new();
63 let fft = planner.plan_fft_forward(width);
64
65 let mut magnitudes = vec![0.0f64; width * height];
66 let mut row_buf = vec![Complex::new(0.0, 0.0); width];
67
68 for row in 0..height {
69 for i in 0..width {
71 row_buf[i] = Complex::new(src.data.get_as_f64(row * width + i).unwrap_or(0.0), 0.0);
72 }
73
74 fft.process(&mut row_buf);
75
76 for (i, c) in row_buf.iter().enumerate() {
78 magnitudes[row * width + i] = c.norm();
79 }
80
81 if suppress_dc {
82 magnitudes[row * width] = 0.0;
83 }
84 }
85
86 let dims = src.dims.clone();
87 let mut arr = NDArray::new(dims, NDDataType::Float64);
88 arr.data = NDDataBuffer::F64(magnitudes);
89 arr.unique_id = src.unique_id;
90 arr.timestamp = src.timestamp;
91 arr.attributes = src.attributes.clone();
92 Some(arr)
93}
94
95pub fn fft_2d(src: &NDArray, suppress_dc: bool) -> Option<NDArray> {
97 if src.dims.len() < 2 {
98 return None;
99 }
100
101 let w = src.dims[0].size;
102 let h = src.dims[1].size;
103
104 if w == 0 || h == 0 {
105 return None;
106 }
107
108 let mut planner = FftPlanner::<f64>::new();
109 let fft_row = planner.plan_fft_forward(w);
110 let fft_col = planner.plan_fft_forward(h);
111
112 let mut data = vec![Complex::new(0.0, 0.0); w * h];
114 let mut row_buf = vec![Complex::new(0.0, 0.0); w];
115
116 for row in 0..h {
117 for i in 0..w {
118 row_buf[i] = Complex::new(src.data.get_as_f64(row * w + i).unwrap_or(0.0), 0.0);
119 }
120 fft_row.process(&mut row_buf);
121 data[row * w..(row * w + w)].copy_from_slice(&row_buf);
122 }
123
124 let mut col_buf = vec![Complex::new(0.0, 0.0); h];
126
127 for col in 0..w {
128 for row in 0..h {
130 col_buf[row] = data[row * w + col];
131 }
132 fft_col.process(&mut col_buf);
133 for row in 0..h {
135 data[row * w + col] = col_buf[row];
136 }
137 }
138
139 let mut magnitudes: Vec<f64> = data.iter().map(|c| c.norm()).collect();
141
142 if suppress_dc {
143 magnitudes[0] = 0.0;
144 }
145
146 let dims = vec![NDDimension::new(w), NDDimension::new(h)];
147 let mut arr = NDArray::new(dims, NDDataType::Float64);
148 arr.data = NDDataBuffer::F64(magnitudes);
149 arr.unique_id = src.unique_id;
150 arr.timestamp = src.timestamp;
151 arr.attributes = src.attributes.clone();
152 Some(arr)
153}
154
/// Indices of the FFT-specific asyn parameters the processor reads or writes.
///
/// Every field starts as `None` (derived `Default`) and is filled by
/// `register_params` from `find_param`. A field stays `None` if the lookup
/// fails.
#[derive(Default)]
struct FFTParamIndices {
    /// `FFT_DIRECTION`: 0 selects forward, any other value inverse.
    direction: Option<usize>,
    /// `FFT_SUPPRESS_DC`: non-zero zeroes the DC component.
    suppress_dc: Option<usize>,
    /// `FFT_NUM_AVERAGE`: frames per averaging block (negative values clamp to 0).
    num_average: Option<usize>,
    /// `FFT_NUM_AVERAGED`: frame count published back from `process_array`.
    num_averaged: Option<usize>,
    /// `FFT_RESET_AVERAGE`: a non-zero write clears the accumulated average.
    reset_average: Option<usize>,
}
164
/// Stateful FFT plugin processor.
///
/// Holds the current configuration, a planner that is reused across frames,
/// and the block-averaging accumulator.
pub struct FFTProcessor {
    /// Active transform settings; updated from `on_param_change`.
    config: FFTConfig,
    /// Reused for every frame so plans for repeated sizes can come from the
    /// planner's cache instead of being rebuilt.
    planner: FftPlanner<f64>,
    /// Element-wise sum of the spectra in the current averaging block; `None`
    /// before the first averaged frame and after any reset.
    avg_buffer: Option<Vec<f64>>,
    /// Number of frames summed into `avg_buffer` in the current block.
    avg_count: usize,
    /// Dimension sizes of the last processed input; a change resets averaging.
    cached_dims: Vec<usize>,
    /// asyn parameter indices resolved in `register_params`.
    params: FFTParamIndices,
}
176
177impl FFTProcessor {
178 pub fn new(mode: FFTMode) -> Self {
179 Self {
180 config: FFTConfig {
181 mode,
182 direction: FFTDirection::Forward,
183 suppress_dc: false,
184 num_average: 0,
185 },
186 planner: FftPlanner::new(),
187 avg_buffer: None,
188 avg_count: 0,
189 cached_dims: Vec::new(),
190 params: FFTParamIndices::default(),
191 }
192 }
193
194 pub fn with_config(config: FFTConfig) -> Self {
195 Self {
196 config,
197 planner: FftPlanner::new(),
198 avg_buffer: None,
199 avg_count: 0,
200 cached_dims: Vec::new(),
201 params: FFTParamIndices::default(),
202 }
203 }
204
205 fn check_dims_changed(&mut self, dims: &[NDDimension]) {
207 let current: Vec<usize> = dims.iter().map(|d| d.size).collect();
208 if current != self.cached_dims {
209 self.cached_dims = current;
210 self.avg_buffer = None;
211 self.avg_count = 0;
212 }
213 }
214
215 fn compute_fft(&mut self, src: &NDArray) -> Option<NDArray> {
217 let suppress_dc = self.config.suppress_dc;
218
219 match (self.config.mode, self.config.direction) {
220 (FFTMode::Rows1D, FFTDirection::Forward) => {
221 self.compute_fft_1d_rows_forward(src, suppress_dc)
222 }
223 (FFTMode::Rows1D, FFTDirection::Inverse) => {
224 self.compute_fft_1d_rows_inverse(src, suppress_dc)
225 }
226 (FFTMode::Full2D, FFTDirection::Forward) => {
227 self.compute_fft_2d_forward(src, suppress_dc)
228 }
229 (FFTMode::Full2D, FFTDirection::Inverse) => {
230 self.compute_fft_2d_inverse(src, suppress_dc)
231 }
232 }
233 }
234
235 fn compute_fft_1d_rows_forward(&mut self, src: &NDArray, suppress_dc: bool) -> Option<NDArray> {
236 if src.dims.is_empty() {
237 return None;
238 }
239
240 let width = src.dims[0].size;
241 let height = if src.dims.len() >= 2 {
242 src.dims[1].size
243 } else {
244 1
245 };
246
247 if width == 0 {
248 return None;
249 }
250
251 let fft = self.planner.plan_fft_forward(width);
252
253 let mut magnitudes = vec![0.0f64; width * height];
254 let mut row_buf = vec![Complex::new(0.0, 0.0); width];
255
256 for row in 0..height {
257 for i in 0..width {
258 row_buf[i] = Complex::new(src.data.get_as_f64(row * width + i).unwrap_or(0.0), 0.0);
259 }
260 fft.process(&mut row_buf);
261 for (i, c) in row_buf.iter().enumerate() {
262 magnitudes[row * width + i] = c.norm();
263 }
264 if suppress_dc {
265 magnitudes[row * width] = 0.0;
266 }
267 }
268
269 let dims = src.dims.clone();
270 let mut arr = NDArray::new(dims, NDDataType::Float64);
271 arr.data = NDDataBuffer::F64(magnitudes);
272 arr.unique_id = src.unique_id;
273 arr.timestamp = src.timestamp;
274 arr.attributes = src.attributes.clone();
275 Some(arr)
276 }
277
278 fn compute_fft_1d_rows_inverse(&mut self, src: &NDArray, suppress_dc: bool) -> Option<NDArray> {
279 if src.dims.is_empty() {
280 return None;
281 }
282
283 let width = src.dims[0].size;
284 let height = if src.dims.len() >= 2 {
285 src.dims[1].size
286 } else {
287 1
288 };
289
290 if width == 0 {
291 return None;
292 }
293
294 let fft = self.planner.plan_fft_inverse(width);
295 let scale = 1.0 / width as f64;
296
297 let mut magnitudes = vec![0.0f64; width * height];
298 let mut row_buf = vec![Complex::new(0.0, 0.0); width];
299
300 for row in 0..height {
301 for i in 0..width {
302 row_buf[i] = Complex::new(src.data.get_as_f64(row * width + i).unwrap_or(0.0), 0.0);
303 }
304 if suppress_dc {
305 row_buf[0] = Complex::new(0.0, 0.0);
306 }
307 fft.process(&mut row_buf);
308 for (i, c) in row_buf.iter().enumerate() {
309 magnitudes[row * width + i] = c.norm() * scale;
310 }
311 }
312
313 let dims = src.dims.clone();
314 let mut arr = NDArray::new(dims, NDDataType::Float64);
315 arr.data = NDDataBuffer::F64(magnitudes);
316 arr.unique_id = src.unique_id;
317 arr.timestamp = src.timestamp;
318 arr.attributes = src.attributes.clone();
319 Some(arr)
320 }
321
322 fn compute_fft_2d_forward(&mut self, src: &NDArray, suppress_dc: bool) -> Option<NDArray> {
323 if src.dims.len() < 2 {
324 return None;
325 }
326
327 let w = src.dims[0].size;
328 let h = src.dims[1].size;
329
330 if w == 0 || h == 0 {
331 return None;
332 }
333
334 let fft_row = self.planner.plan_fft_forward(w);
335 let fft_col = self.planner.plan_fft_forward(h);
336
337 let mut data = vec![Complex::new(0.0, 0.0); w * h];
338 let mut row_buf = vec![Complex::new(0.0, 0.0); w];
339
340 for row in 0..h {
341 for i in 0..w {
342 row_buf[i] = Complex::new(src.data.get_as_f64(row * w + i).unwrap_or(0.0), 0.0);
343 }
344 fft_row.process(&mut row_buf);
345 data[row * w..(row * w + w)].copy_from_slice(&row_buf);
346 }
347
348 let mut col_buf = vec![Complex::new(0.0, 0.0); h];
349 for col in 0..w {
350 for row in 0..h {
351 col_buf[row] = data[row * w + col];
352 }
353 fft_col.process(&mut col_buf);
354 for row in 0..h {
355 data[row * w + col] = col_buf[row];
356 }
357 }
358
359 let mut magnitudes: Vec<f64> = data.iter().map(|c| c.norm()).collect();
360
361 if suppress_dc {
362 magnitudes[0] = 0.0;
363 }
364
365 let dims = vec![NDDimension::new(w), NDDimension::new(h)];
366 let mut arr = NDArray::new(dims, NDDataType::Float64);
367 arr.data = NDDataBuffer::F64(magnitudes);
368 arr.unique_id = src.unique_id;
369 arr.timestamp = src.timestamp;
370 arr.attributes = src.attributes.clone();
371 Some(arr)
372 }
373
374 fn compute_fft_2d_inverse(&mut self, src: &NDArray, suppress_dc: bool) -> Option<NDArray> {
375 if src.dims.len() < 2 {
376 return None;
377 }
378
379 let w = src.dims[0].size;
380 let h = src.dims[1].size;
381
382 if w == 0 || h == 0 {
383 return None;
384 }
385
386 let fft_row = self.planner.plan_fft_inverse(w);
387 let fft_col = self.planner.plan_fft_inverse(h);
388 let scale = 1.0 / (w * h) as f64;
389
390 let mut data = vec![Complex::new(0.0, 0.0); w * h];
391 for i in 0..w * h {
392 data[i] = Complex::new(src.data.get_as_f64(i).unwrap_or(0.0), 0.0);
393 }
394
395 if suppress_dc {
396 data[0] = Complex::new(0.0, 0.0);
397 }
398
399 let mut col_buf = vec![Complex::new(0.0, 0.0); h];
400 for col in 0..w {
401 for row in 0..h {
402 col_buf[row] = data[row * w + col];
403 }
404 fft_col.process(&mut col_buf);
405 for row in 0..h {
406 data[row * w + col] = col_buf[row];
407 }
408 }
409
410 let mut row_buf = vec![Complex::new(0.0, 0.0); w];
411 for row in 0..h {
412 row_buf.copy_from_slice(&data[row * w..(row * w + w)]);
413 fft_row.process(&mut row_buf);
414 data[row * w..(row * w + w)].copy_from_slice(&row_buf);
415 }
416
417 let magnitudes: Vec<f64> = data.iter().map(|c| c.norm() * scale).collect();
418
419 let dims = vec![NDDimension::new(w), NDDimension::new(h)];
420 let mut arr = NDArray::new(dims, NDDataType::Float64);
421 arr.data = NDDataBuffer::F64(magnitudes);
422 arr.unique_id = src.unique_id;
423 arr.timestamp = src.timestamp;
424 arr.attributes = src.attributes.clone();
425 Some(arr)
426 }
427
428 fn apply_averaging(&mut self, magnitudes: &[f64]) -> Vec<f64> {
430 let num_avg = self.config.num_average;
431 if num_avg <= 1 {
432 return magnitudes.to_vec();
433 }
434
435 let buf = self
436 .avg_buffer
437 .get_or_insert_with(|| vec![0.0; magnitudes.len()]);
438
439 if buf.len() != magnitudes.len() {
441 *buf = vec![0.0; magnitudes.len()];
442 self.avg_count = 0;
443 }
444
445 for (b, &m) in buf.iter_mut().zip(magnitudes.iter()) {
447 *b += m;
448 }
449 self.avg_count += 1;
450
451 if self.avg_count >= num_avg {
452 let result: Vec<f64> = buf.iter().map(|&v| v / self.avg_count as f64).collect();
454 buf.iter_mut().for_each(|v| *v = 0.0);
455 self.avg_count = 0;
456 result
457 } else {
458 buf.iter().map(|&v| v / self.avg_count as f64).collect()
460 }
461 }
462}
463
464impl NDPluginProcess for FFTProcessor {
465 fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
466 use ad_core_rs::plugin::runtime::ParamUpdate;
467
468 self.check_dims_changed(&array.dims);
469
470 let result = self.compute_fft(array);
471 let mut updates = Vec::new();
472 if let Some(idx) = self.params.num_averaged {
473 updates.push(ParamUpdate::int32(idx, self.avg_count as i32));
474 }
475
476 match result {
477 Some(mut out) => {
478 if self.config.num_average > 1 {
479 if let NDDataBuffer::F64(ref mags) = out.data {
480 let averaged = self.apply_averaging(mags);
481 out.data = NDDataBuffer::F64(averaged);
482 }
483 }
484 let mut r = ProcessResult::arrays(vec![Arc::new(out)]);
485 r.param_updates = updates;
486 r
487 }
488 None => ProcessResult::sink(updates),
489 }
490 }
491
492 fn plugin_type(&self) -> &str {
493 "NDPluginFFT"
494 }
495
496 fn register_params(
497 &mut self,
498 base: &mut asyn_rs::port::PortDriverBase,
499 ) -> asyn_rs::error::AsynResult<()> {
500 use asyn_rs::param::ParamType;
501 base.create_param("FFT_TIME_PER_POINT", ParamType::Float64)?;
502 base.create_param("FFT_TIME_AXIS", ParamType::Float64Array)?;
503 base.create_param("FFT_FREQ_AXIS", ParamType::Float64Array)?;
504 base.create_param("FFT_DIRECTION", ParamType::Int32)?;
505 base.create_param("FFT_SUPPRESS_DC", ParamType::Int32)?;
506 base.create_param("FFT_NUM_AVERAGE", ParamType::Int32)?;
507 base.create_param("FFT_NUM_AVERAGED", ParamType::Int32)?;
508 base.create_param("FFT_RESET_AVERAGE", ParamType::Int32)?;
509 base.create_param("FFT_TIME_SERIES", ParamType::Float64Array)?;
510 base.create_param("FFT_REAL", ParamType::Float64Array)?;
511 base.create_param("FFT_IMAGINARY", ParamType::Float64Array)?;
512 base.create_param("FFT_ABS_VALUE", ParamType::Float64Array)?;
513
514 self.params.direction = base.find_param("FFT_DIRECTION");
515 self.params.suppress_dc = base.find_param("FFT_SUPPRESS_DC");
516 self.params.num_average = base.find_param("FFT_NUM_AVERAGE");
517 self.params.num_averaged = base.find_param("FFT_NUM_AVERAGED");
518 self.params.reset_average = base.find_param("FFT_RESET_AVERAGE");
519 Ok(())
520 }
521
522 fn on_param_change(
523 &mut self,
524 reason: usize,
525 params: &ad_core_rs::plugin::runtime::PluginParamSnapshot,
526 ) -> ad_core_rs::plugin::runtime::ParamChangeResult {
527 if Some(reason) == self.params.direction {
528 self.config.direction = if params.value.as_i32() == 0 {
529 FFTDirection::Forward
530 } else {
531 FFTDirection::Inverse
532 };
533 } else if Some(reason) == self.params.suppress_dc {
534 self.config.suppress_dc = params.value.as_i32() != 0;
535 } else if Some(reason) == self.params.num_average {
536 self.config.num_average = params.value.as_i32().max(0) as usize;
537 } else if Some(reason) == self.params.reset_average {
538 if params.value.as_i32() != 0 {
539 self.avg_buffer = None;
540 self.avg_count = 0;
541 }
542 }
543 ad_core_rs::plugin::runtime::ParamChangeResult::updates(vec![])
544 }
545}
546
// Unit tests for the free FFT helpers and for `FFTProcessor` end to end.
#[cfg(test)]
mod tests {
    use super::*;

    // A constant signal puts all of its energy in bin 0.
    #[test]
    fn test_fft_1d_dc() {
        let mut arr = NDArray::new(vec![NDDimension::new(8)], NDDataType::Float64);
        if let NDDataBuffer::F64(ref mut v) = arr.data {
            for i in 0..8 {
                v[i] = 1.0;
            }
        }

        let result = fft_1d_rows(&arr, false).unwrap();
        // NOTE(review): no else branch, so a non-F64 result would pass silently.
        if let NDDataBuffer::F64(ref v) = result.data {
            // DC bin equals the sum of the samples: 8 * 1.0.
            assert!((v[0] - 8.0).abs() < 1e-10);
            // No energy at non-zero frequencies.
            assert!(v[1].abs() < 1e-10);
        }
    }

    // One full sine period across n samples lands in bin 1 (and its mirror n-1).
    #[test]
    fn test_fft_1d_sine() {
        let n = 16;
        let mut arr = NDArray::new(vec![NDDimension::new(n)], NDDataType::Float64);
        if let NDDataBuffer::F64(ref mut v) = arr.data {
            for i in 0..n {
                v[i] = (2.0 * std::f64::consts::PI * i as f64 / n as f64).sin();
            }
        }

        let result = fft_1d_rows(&arr, false).unwrap();
        // NOTE(review): no else branch, so a non-F64 result would pass silently.
        if let NDDataBuffer::F64(ref v) = result.data {
            assert!(v[0].abs() < 1e-10);
            // Expected magnitude is n / 2 = 8.
            assert!(v[1] > 7.0);
            assert!(v[2].abs() < 1e-10);
        }
    }

    // 2-D output keeps the (w, h) shape and is always Float64, even for UInt8 input.
    #[test]
    fn test_fft_2d_dimensions() {
        let arr = NDArray::new(
            vec![NDDimension::new(4), NDDimension::new(4)],
            NDDataType::UInt8,
        );
        let result = fft_2d(&arr, false).unwrap();
        assert_eq!(result.dims[0].size, 4);
        assert_eq!(result.dims[1].size, 4);
        assert_eq!(result.data.data_type(), NDDataType::Float64);
    }

    // suppress_dc zeroes bin 0 of the 1-D spectrum and leaves other bins alone.
    #[test]
    fn test_fft_1d_suppress_dc() {
        let mut arr = NDArray::new(vec![NDDimension::new(8)], NDDataType::Float64);
        if let NDDataBuffer::F64(ref mut v) = arr.data {
            for i in 0..8 {
                v[i] = 1.0;
            }
        }

        let result = fft_1d_rows(&arr, true).unwrap();
        if let NDDataBuffer::F64(ref v) = result.data {
            assert!((v[0]).abs() < 1e-15);
            assert!(v[1].abs() < 1e-10);
        } else {
            panic!("expected F64 data");
        }
    }

    // suppress_dc zeroes the single DC bin of the 2-D spectrum.
    #[test]
    fn test_fft_2d_suppress_dc() {
        let mut arr = NDArray::new(
            vec![NDDimension::new(4), NDDimension::new(4)],
            NDDataType::Float64,
        );
        if let NDDataBuffer::F64(ref mut v) = arr.data {
            for val in v.iter_mut() {
                *val = 3.0;
            }
        }

        let result = fft_2d(&arr, true).unwrap();
        if let NDDataBuffer::F64(ref v) = result.data {
            assert!((v[0]).abs() < 1e-15);
        } else {
            panic!("expected F64 data");
        }
    }

    // 16 samples of 2.0: the 2-D DC bin is their sum (32); every other bin is ~0.
    #[test]
    fn test_fft_2d_known_dc() {
        let mut arr = NDArray::new(
            vec![NDDimension::new(4), NDDimension::new(4)],
            NDDataType::Float64,
        );
        if let NDDataBuffer::F64(ref mut v) = arr.data {
            for val in v.iter_mut() {
                *val = 2.0;
            }
        }

        let result = fft_2d(&arr, false).unwrap();
        if let NDDataBuffer::F64(ref v) = result.data {
            assert!((v[0] - 32.0).abs() < 1e-10, "DC = {}, expected 32", v[0]);
            for i in 1..v.len() {
                assert!(v[i].abs() < 1e-10, "bin {} = {}, expected ~0", i, v[i]);
            }
        } else {
            panic!("expected F64 data");
        }
    }

    // A cosine at k=3 appears at bins 3 and n-3 = 13, each with magnitude n/2 = 8.
    #[test]
    fn test_fft_1d_known_cosine_peaks() {
        let n = 16;
        let mut arr = NDArray::new(vec![NDDimension::new(n)], NDDataType::Float64);
        if let NDDataBuffer::F64(ref mut v) = arr.data {
            for i in 0..n {
                v[i] = (2.0 * std::f64::consts::PI * 3.0 * i as f64 / n as f64).cos();
            }
        }

        let result = fft_1d_rows(&arr, false).unwrap();
        if let NDDataBuffer::F64(ref v) = result.data {
            assert!(v[0].abs() < 1e-10);
            assert!(
                (v[3] - 8.0).abs() < 1e-10,
                "k=3 magnitude = {}, expected 8",
                v[3]
            );
            assert!(
                (v[13] - 8.0).abs() < 1e-10,
                "k=13 magnitude = {}, expected 8",
                v[13]
            );
            // Every bin except the two peaks must be empty.
            for k in [0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15] {
                assert!(
                    v[k].abs() < 1e-10,
                    "k={} magnitude = {}, expected ~0",
                    k,
                    v[k]
                );
            }
        } else {
            panic!("expected F64 data");
        }
    }

    // The processor honours suppress_dc from a config passed to with_config.
    #[test]
    fn test_processor_with_config() {
        let config = FFTConfig {
            mode: FFTMode::Rows1D,
            direction: FFTDirection::Forward,
            suppress_dc: true,
            num_average: 0,
        };
        let mut proc = FFTProcessor::with_config(config);
        let pool = NDArrayPool::new(0);

        let mut arr = NDArray::new(vec![NDDimension::new(8)], NDDataType::Float64);
        if let NDDataBuffer::F64(ref mut v) = arr.data {
            for i in 0..8 {
                v[i] = 5.0;
            }
        }

        let result = proc.process_array(&arr, &pool);
        assert_eq!(result.output_arrays.len(), 1);
        if let NDDataBuffer::F64(ref v) = result.output_arrays[0].data {
            assert!(v[0].abs() < 1e-15);
        } else {
            panic!("expected F64 data");
        }
    }

    // Block averaging with num_average = 2: the first frame is emitted as a
    // one-frame mean, and the second completes the block.
    #[test]
    fn test_processor_averaging() {
        let config = FFTConfig {
            mode: FFTMode::Rows1D,
            direction: FFTDirection::Forward,
            suppress_dc: false,
            num_average: 2,
        };
        let mut proc = FFTProcessor::with_config(config);
        let pool = NDArrayPool::new(0);

        let mut arr1 = NDArray::new(vec![NDDimension::new(8)], NDDataType::Float64);
        if let NDDataBuffer::F64(ref mut v) = arr1.data {
            for i in 0..8 {
                v[i] = 2.0;
            }
        }

        let mut arr2 = NDArray::new(vec![NDDimension::new(8)], NDDataType::Float64);
        if let NDDataBuffer::F64(ref mut v) = arr2.data {
            for i in 0..8 {
                v[i] = 4.0;
            }
        }

        let r1 = proc.process_array(&arr1, &pool);
        assert_eq!(r1.output_arrays.len(), 1);
        if let NDDataBuffer::F64(ref v) = r1.output_arrays[0].data {
            // DC of arr1 is 8 * 2.0 = 16; a mean of one frame.
            assert!((v[0] - 16.0).abs() < 1e-10, "partial avg DC = {}", v[0]);
        }

        let r2 = proc.process_array(&arr2, &pool);
        assert_eq!(r2.output_arrays.len(), 1);
        if let NDDataBuffer::F64(ref v) = r2.output_arrays[0].data {
            // DC of arr2 is 32; the mean of 16 and 32 is 24.
            assert!((v[0] - 24.0).abs() < 1e-10, "averaged DC = {}", v[0]);
        }
    }

    // A change in input dimensions discards the partial average, so the count
    // restarts at 1 instead of reaching 2.
    #[test]
    fn test_processor_averaging_dimension_change_resets() {
        let config = FFTConfig {
            mode: FFTMode::Rows1D,
            direction: FFTDirection::Forward,
            suppress_dc: false,
            num_average: 3,
        };
        let mut proc = FFTProcessor::with_config(config);
        let pool = NDArrayPool::new(0);

        let mut arr1 = NDArray::new(vec![NDDimension::new(8)], NDDataType::Float64);
        if let NDDataBuffer::F64(ref mut v) = arr1.data {
            for i in 0..8 {
                v[i] = 1.0;
            }
        }
        let _ = proc.process_array(&arr1, &pool);
        assert_eq!(proc.avg_count, 1);

        let mut arr2 = NDArray::new(vec![NDDimension::new(4)], NDDataType::Float64);
        if let NDDataBuffer::F64(ref mut v) = arr2.data {
            for i in 0..4 {
                v[i] = 1.0;
            }
        }
        let _ = proc.process_array(&arr2, &pool);
        assert_eq!(proc.avg_count, 1);
    }

    // Rows are transformed independently: row 0 DC = 4 * 1.0, row 1 DC = 4 * 3.0.
    #[test]
    fn test_fft_1d_multirow() {
        let w = 4;
        let h = 2;
        let mut arr = NDArray::new(
            vec![NDDimension::new(w), NDDimension::new(h)],
            NDDataType::Float64,
        );
        if let NDDataBuffer::F64(ref mut v) = arr.data {
            for i in 0..w {
                v[i] = 1.0;
            }
            for i in w..2 * w {
                v[i] = 3.0;
            }
        }

        let result = fft_1d_rows(&arr, false).unwrap();
        if let NDDataBuffer::F64(ref v) = result.data {
            assert!((v[0] - 4.0).abs() < 1e-10);
            assert!((v[w] - 12.0).abs() < 1e-10);
        } else {
            panic!("expected F64 data");
        }
    }

    // An impulse of height n at bin 0 inverse-transforms to a constant that
    // becomes 1.0 after the 1/n scale.
    #[test]
    fn test_inverse_fft_1d() {
        let n = 8;
        let mut arr = NDArray::new(vec![NDDimension::new(n)], NDDataType::Float64);
        if let NDDataBuffer::F64(ref mut v) = arr.data {
            v[0] = 8.0;
        }

        let config = FFTConfig {
            mode: FFTMode::Rows1D,
            direction: FFTDirection::Inverse,
            suppress_dc: false,
            num_average: 0,
        };
        let mut proc = FFTProcessor::with_config(config);
        let pool = NDArrayPool::new(0);

        let result = proc.process_array(&arr, &pool);
        assert_eq!(result.output_arrays.len(), 1);
        if let NDDataBuffer::F64(ref v) = result.output_arrays[0].data {
            for i in 0..n {
                assert!(
                    (v[i] - 1.0).abs() < 1e-10,
                    "sample {} = {}, expected 1.0",
                    i,
                    v[i]
                );
            }
        } else {
            panic!("expected F64 data");
        }
    }

    // unique_id and timestamp are copied from the source array.
    #[test]
    fn test_fft_preserves_metadata() {
        let mut arr = NDArray::new(vec![NDDimension::new(4)], NDDataType::Float64);
        arr.unique_id = 42;
        if let NDDataBuffer::F64(ref mut v) = arr.data {
            v[0] = 1.0;
        }

        let result = fft_1d_rows(&arr, false).unwrap();
        assert_eq!(result.unique_id, 42);
        assert_eq!(result.timestamp, arr.timestamp);
    }
}