1use crate::algorithm_selector::{AlgorithmSelector, CacheInfo, FftAlgorithm, InputCharacteristics};
33use crate::error::{FFTError, FFTResult};
34#[cfg(feature = "oxifft")]
35use crate::oxifft_plan_cache;
36#[cfg(feature = "oxifft")]
37use oxifft::{Complex as OxiComplex, Direction};
38use scirs2_core::numeric::Complex64;
39use scirs2_core::numeric::NumCast;
40use std::fmt::Debug;
41
/// Tuning knobs for the large-input FFT routines in this module.
///
/// `l2_cache_size`, `l3_cache_size` and `target_memory_bytes` are compared
/// against a byte estimate of the input by `LargeFft::select_method`, so they
/// are expressed in bytes.
#[derive(Debug, Clone)]
pub struct LargeFftConfig {
    /// Block length in samples: lower bound for the overlap-save block and
    /// the extra working set counted by `LargeFft::estimate_memory`.
    pub max_block_size: usize,
    /// Memory budget in bytes; inputs estimated above it are classified as
    /// `LargeFftMethod::OutOfCore`.
    pub target_memory_bytes: usize,
    /// Requests parallel execution.
    /// NOTE(review): not read by any code visible in this module - confirm usage.
    pub use_parallel: bool,
    /// Number of worker threads; defaults to the detected CPU count.
    /// NOTE(review): not read by any code visible in this module - confirm usage.
    pub num_threads: usize,
    /// Cache line size as reported by `CacheInfo` (presumably bytes).
    pub cache_line_size: usize,
    /// L1 cache size as reported by `CacheInfo` (presumably bytes).
    pub l1_cache_size: usize,
    /// L2 cache size in bytes; upper bound for `LargeFftMethod::Direct`.
    pub l2_cache_size: usize,
    /// L3 cache size in bytes; upper bound for `LargeFftMethod::CacheBlocked`.
    pub l3_cache_size: usize,
    /// Requests the overlap-save strategy.
    /// NOTE(review): not read by any code visible in this module - confirm usage.
    pub use_overlap_save: bool,
    /// Overlap fraction associated with overlap-save processing.
    /// NOTE(review): not read by any code visible in this module - confirm usage.
    pub overlap_ratio: f64,
}
66
67impl Default for LargeFftConfig {
68 fn default() -> Self {
69 let cache_info = CacheInfo::default();
70 Self {
71 max_block_size: 65536, target_memory_bytes: 256 * 1024 * 1024, use_parallel: true,
74 num_threads: num_cpus::get(),
75 cache_line_size: cache_info.cache_line_size,
76 l1_cache_size: cache_info.l1_size,
77 l2_cache_size: cache_info.l2_size,
78 l3_cache_size: cache_info.l3_size,
79 use_overlap_save: false,
80 overlap_ratio: 0.5,
81 }
82 }
83}
84
/// Processing strategy chosen for an input, ordered by growing footprint.
///
/// The classification is made by `LargeFft::select_method`, which compares a
/// byte estimate of the input against the configured cache and memory limits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LargeFftMethod {
    /// Estimated footprint fits within the configured L2 cache size.
    Direct,
    /// Estimated footprint exceeds L2 but fits within the configured L3 size.
    CacheBlocked,
    /// Estimated footprint exceeds L3 but fits within the memory budget.
    Streaming,
    /// Estimated footprint exceeds the configured memory budget.
    OutOfCore,
}
97
98#[derive(Debug, Clone)]
100pub struct LargeFftStats {
101 pub method: LargeFftMethod,
103 pub num_blocks: usize,
105 pub block_size: usize,
107 pub peak_memory_bytes: usize,
109 pub total_time_ns: u64,
111}
112
113pub struct LargeFft {
115 config: LargeFftConfig,
117 selector: AlgorithmSelector,
119}
120
121impl Default for LargeFft {
122 fn default() -> Self {
123 Self::new(LargeFftConfig::default())
124 }
125}
126
127impl LargeFft {
128 pub fn with_defaults() -> Self {
130 Self::new(LargeFftConfig::default())
131 }
132
133 pub fn new(config: LargeFftConfig) -> Self {
135 Self {
136 config,
137 selector: AlgorithmSelector::new(),
138 }
139 }
140
141 pub fn select_method(&self, size: usize) -> LargeFftMethod {
143 let memory_required = size * 16; if memory_required <= self.config.l2_cache_size {
146 LargeFftMethod::Direct
147 } else if memory_required <= self.config.l3_cache_size {
148 LargeFftMethod::CacheBlocked
149 } else if memory_required <= self.config.target_memory_bytes {
150 LargeFftMethod::Streaming
151 } else {
152 LargeFftMethod::OutOfCore
153 }
154 }
155
156 pub fn compute<T>(&self, input: &[T], forward: bool) -> FFTResult<Vec<Complex64>>
158 where
159 T: NumCast + Copy + Debug + 'static,
160 {
161 let size = input.len();
162 if size == 0 {
163 return Err(FFTError::ValueError("Input cannot be empty".to_string()));
164 }
165
166 let method = self.select_method(size);
167
168 match method {
169 LargeFftMethod::Direct => self.compute_direct(input, forward),
170 LargeFftMethod::CacheBlocked => self.compute_cache_blocked(input, forward),
171 LargeFftMethod::Streaming => self.compute_streaming(input, forward),
172 LargeFftMethod::OutOfCore => self.compute_out_of_core(input, forward),
173 }
174 }
175
176 pub fn compute_complex(&self, input: &[Complex64], forward: bool) -> FFTResult<Vec<Complex64>> {
181 let size = input.len();
182 if size == 0 {
183 return Err(FFTError::ValueError("Input cannot be empty".to_string()));
184 }
185
186 self.compute_direct_complex(input, forward)
187 }
188
189 fn compute_direct_complex(
191 &self,
192 input: &[Complex64],
193 forward: bool,
194 ) -> FFTResult<Vec<Complex64>> {
195 let size = input.len();
196 let data: Vec<Complex64> = input.to_vec();
197
198 #[cfg(feature = "oxifft")]
199 {
200 let input_oxi: Vec<OxiComplex<f64>> =
201 data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
202 let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
203
204 let direction = if forward {
205 Direction::Forward
206 } else {
207 Direction::Backward
208 };
209 oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction)?;
210
211 let mut result: Vec<Complex64> = output
212 .into_iter()
213 .map(|c| Complex64::new(c.re, c.im))
214 .collect();
215
216 if !forward {
217 let scale = 1.0 / size as f64;
218 for val in &mut result {
219 *val *= scale;
220 }
221 }
222
223 Ok(result)
224 }
225 }
226
227 fn compute_direct<T>(&self, input: &[T], forward: bool) -> FFTResult<Vec<Complex64>>
229 where
230 T: NumCast + Copy + Debug + 'static,
231 {
232 let size = input.len();
233
234 let data: Vec<Complex64> = input
236 .iter()
237 .map(|val| {
238 let real: f64 = NumCast::from(*val).unwrap_or(0.0);
239 Complex64::new(real, 0.0)
240 })
241 .collect();
242
243 #[cfg(feature = "oxifft")]
245 {
246 let input_oxi: Vec<OxiComplex<f64>> =
248 data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
249 let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
250
251 let direction = if forward {
253 Direction::Forward
254 } else {
255 Direction::Backward
256 };
257 oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction)?;
258
259 let mut result: Vec<Complex64> = output
261 .into_iter()
262 .map(|c| Complex64::new(c.re, c.im))
263 .collect();
264
265 if !forward {
267 let scale = 1.0 / size as f64;
268 for val in &mut result {
269 *val *= scale;
270 }
271 }
272
273 Ok(result)
274 }
275 }
276
277 fn compute_cache_blocked<T>(&self, input: &[T], forward: bool) -> FFTResult<Vec<Complex64>>
279 where
280 T: NumCast + Copy + Debug + 'static,
281 {
282 let size = input.len();
283
284 let elements_per_cache = self.config.l2_cache_size / 16; let block_size = find_optimal_block_size(size, elements_per_cache);
287
288 let data: Vec<Complex64> = input
290 .iter()
291 .map(|val| {
292 let real: f64 = NumCast::from(*val).unwrap_or(0.0);
293 Complex64::new(real, 0.0)
294 })
295 .collect();
296
297 #[cfg(feature = "oxifft")]
305 {
306 let input_oxi: Vec<OxiComplex<f64>> =
308 data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
309 let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
310
311 let direction = if forward {
313 Direction::Forward
314 } else {
315 Direction::Backward
316 };
317 oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction)?;
318
319 let mut result: Vec<Complex64> = output
321 .into_iter()
322 .map(|c| Complex64::new(c.re, c.im))
323 .collect();
324
325 if !forward {
327 let scale = 1.0 / size as f64;
328 for val in &mut result {
329 *val *= scale;
330 }
331 }
332
333 let _ = block_size;
335
336 Ok(result)
337 }
338 }
339
340 fn compute_streaming<T>(&self, input: &[T], forward: bool) -> FFTResult<Vec<Complex64>>
342 where
343 T: NumCast + Copy + Debug + 'static,
344 {
345 let size = input.len();
346
347 let chunk_size = self.config.max_block_size;
356 let mut data: Vec<Complex64> = Vec::with_capacity(size);
357
358 for chunk in input.chunks(chunk_size) {
359 for val in chunk {
360 let real: f64 = NumCast::from(*val).unwrap_or(0.0);
361 data.push(Complex64::new(real, 0.0));
362 }
363 }
364
365 #[cfg(feature = "oxifft")]
367 {
368 let input_oxi: Vec<OxiComplex<f64>> =
370 data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
371 let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
372
373 let direction = if forward {
375 Direction::Forward
376 } else {
377 Direction::Backward
378 };
379 oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction)?;
380
381 let mut result: Vec<Complex64> = output
383 .into_iter()
384 .map(|c| Complex64::new(c.re, c.im))
385 .collect();
386
387 if !forward {
389 let scale = 1.0 / size as f64;
390 for chunk in result.chunks_mut(chunk_size) {
392 for val in chunk {
393 *val *= scale;
394 }
395 }
396 }
397
398 Ok(result)
399 }
400 }
401
402 fn compute_out_of_core<T>(&self, input: &[T], forward: bool) -> FFTResult<Vec<Complex64>>
404 where
405 T: NumCast + Copy + Debug + 'static,
406 {
407 eprintln!(
410 "Warning: Input size {} exceeds target memory, using streaming method",
411 input.len()
412 );
413 self.compute_streaming(input, forward)
414 }
415
416 pub fn compute_overlap_save<T>(
418 &self,
419 input: &[T],
420 filter_len: usize,
421 forward: bool,
422 ) -> FFTResult<Vec<Complex64>>
423 where
424 T: NumCast + Copy + Debug + 'static,
425 {
426 let input_len = input.len();
427
428 if filter_len == 0 {
429 return Err(FFTError::ValueError(
430 "Filter length must be positive".to_string(),
431 ));
432 }
433
434 let block_size = (self.config.max_block_size).max(filter_len * 4);
436 let fft_size = block_size.next_power_of_two();
437 let valid_output_per_block = fft_size - filter_len + 1;
438
439 let num_blocks = input_len.div_ceil(valid_output_per_block);
441
442 let output_len = input_len;
444 let mut output = Vec::with_capacity(output_len);
445
446 #[cfg(feature = "oxifft")]
448 {
449 let mut buffer = vec![Complex64::new(0.0, 0.0); fft_size];
451
452 for block_idx in 0..num_blocks {
453 let input_start = if block_idx == 0 {
454 0
455 } else {
456 block_idx * valid_output_per_block - (filter_len - 1)
457 };
458
459 for val in &mut buffer {
461 *val = Complex64::new(0.0, 0.0);
462 }
463
464 for (i, j) in (input_start..)
466 .take(fft_size.min(input_len - input_start))
467 .enumerate()
468 {
469 if j < input_len {
470 let real: f64 = NumCast::from(input[j]).unwrap_or(0.0);
471 buffer[i] = Complex64::new(real, 0.0);
472 }
473 }
474
475 let input_oxi: Vec<OxiComplex<f64>> =
477 buffer.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
478 let mut output_oxi: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); fft_size];
479 oxifft_plan_cache::execute_c2c(&input_oxi, &mut output_oxi, Direction::Forward)?;
480
481 for (i, val) in output_oxi.iter().enumerate() {
483 buffer[i] = Complex64::new(val.re, val.im);
484 }
485
486 if !forward {
490 let input_oxi: Vec<OxiComplex<f64>> =
492 buffer.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
493 let mut output_oxi: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); fft_size];
494 oxifft_plan_cache::execute_c2c(
495 &input_oxi,
496 &mut output_oxi,
497 Direction::Backward,
498 )?;
499
500 let scale = 1.0 / fft_size as f64;
502 for (i, val) in output_oxi.iter().enumerate() {
503 buffer[i] = Complex64::new(val.re * scale, val.im * scale);
504 }
505 }
506
507 let output_start = if block_idx == 0 { 0 } else { filter_len - 1 };
509 let output_count = valid_output_per_block.min(output_len - output.len());
510
511 for i in output_start..(output_start + output_count) {
512 if i < fft_size {
513 output.push(buffer[i]);
514 }
515 }
516 }
517
518 Ok(output)
519 }
520 }
521
522 pub fn config(&self) -> &LargeFftConfig {
524 &self.config
525 }
526
527 pub fn selector(&self) -> &AlgorithmSelector {
529 &self.selector
530 }
531
532 pub fn estimate_memory(&self, size: usize) -> usize {
534 let method = self.select_method(size);
535 let base_memory = size * 16; match method {
538 LargeFftMethod::Direct => base_memory * 2, LargeFftMethod::CacheBlocked => base_memory * 2 + self.config.l2_cache_size,
540 LargeFftMethod::Streaming => base_memory + self.config.max_block_size * 16,
541 LargeFftMethod::OutOfCore => self.config.target_memory_bytes,
542 }
543 }
544}
545
/// Returns the largest power of two that does not exceed either
/// `total_size` or `cache_elements`, and never less than 1.
///
/// When either limit is 0 (or 1) the result is 1. The value is computed from
/// the position of the highest set bit of the smaller limit, so limits close
/// to `usize::MAX` cannot overflow the way a repeated doubling loop would.
fn find_optimal_block_size(total_size: usize, cache_elements: usize) -> usize {
    let limit = total_size.min(cache_elements);
    if limit == 0 {
        return 1;
    }
    // `ilog2` is floor(log2(limit)); shifting back gives the power of two.
    1usize << limit.ilog2()
}
555
556pub struct LargeFftNd {
558 fft_1d: LargeFft,
560 config: LargeFftConfig,
562}
563
564impl Default for LargeFftNd {
565 fn default() -> Self {
566 Self::new(LargeFftConfig::default())
567 }
568}
569
570impl LargeFftNd {
571 pub fn new(config: LargeFftConfig) -> Self {
573 Self {
574 fft_1d: LargeFft::new(config.clone()),
575 config,
576 }
577 }
578
579 pub fn compute_2d<T>(
581 &self,
582 input: &[T],
583 rows: usize,
584 cols: usize,
585 forward: bool,
586 ) -> FFTResult<Vec<Complex64>>
587 where
588 T: NumCast + Copy + Debug + 'static,
589 {
590 if input.len() != rows * cols {
591 return Err(FFTError::ValueError(format!(
592 "Input size {} doesn't match dimensions {}x{}",
593 input.len(),
594 rows,
595 cols
596 )));
597 }
598
599 let mut data: Vec<Complex64> = input
601 .iter()
602 .map(|val| {
603 let real: f64 = NumCast::from(*val).unwrap_or(0.0);
604 Complex64::new(real, 0.0)
605 })
606 .collect();
607
608 let mut row_buffer = vec![Complex64::new(0.0, 0.0); cols];
610 for r in 0..rows {
611 for c in 0..cols {
613 row_buffer[c] = data[r * cols + c];
614 }
615
616 let row_fft = self.fft_1d.compute_direct(&row_buffer, forward)?;
618
619 for c in 0..cols {
621 data[r * cols + c] = row_fft[c];
622 }
623 }
624
625 let mut col_buffer = vec![Complex64::new(0.0, 0.0); rows];
627 for c in 0..cols {
628 for r in 0..rows {
630 col_buffer[r] = data[r * cols + c];
631 }
632
633 let col_fft = self.fft_1d.compute_direct(&col_buffer, forward)?;
635
636 for r in 0..rows {
638 data[r * cols + c] = col_fft[r];
639 }
640 }
641
642 Ok(data)
643 }
644
645 pub fn compute_nd<T>(
647 &self,
648 input: &[T],
649 shape: &[usize],
650 forward: bool,
651 ) -> FFTResult<Vec<Complex64>>
652 where
653 T: NumCast + Copy + Debug + 'static,
654 {
655 let total_size: usize = shape.iter().product();
656 if input.len() != total_size {
657 return Err(FFTError::ValueError(format!(
658 "Input size {} doesn't match shape {:?} (expected {})",
659 input.len(),
660 shape,
661 total_size
662 )));
663 }
664
665 let mut data: Vec<Complex64> = input
667 .iter()
668 .map(|val| {
669 let real: f64 = NumCast::from(*val).unwrap_or(0.0);
670 Complex64::new(real, 0.0)
671 })
672 .collect();
673
674 for (dim_idx, &dim_size) in shape.iter().enumerate() {
676 let stride = shape[(dim_idx + 1)..].iter().product::<usize>().max(1);
677 let outer_size = shape[..dim_idx].iter().product::<usize>().max(1);
678
679 let mut line_buffer = vec![Complex64::new(0.0, 0.0); dim_size];
680
681 for outer in 0..outer_size {
682 let outer_offset = outer * shape[dim_idx..].iter().product::<usize>().max(1);
683
684 for inner in 0..(total_size / (outer_size * dim_size)) {
685 for i in 0..dim_size {
687 let idx = outer_offset + i * stride + inner;
688 if idx < data.len() {
689 line_buffer[i] = data[idx];
690 }
691 }
692
693 let line_fft = self.fft_1d.compute_direct(&line_buffer, forward)?;
695
696 for i in 0..dim_size {
698 let idx = outer_offset + i * stride + inner;
699 if idx < data.len() {
700 data[idx] = line_fft[i];
701 }
702 }
703 }
704 }
705 }
706
707 Ok(data)
708 }
709}
710
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Absolute tolerance shared by the numerical comparisons below.
    const TOL: f64 = 1e-10;

    /// The DC bin of a forward transform equals the sum of the samples.
    #[test]
    fn test_large_fft_direct() {
        let samples: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];

        let spectrum = LargeFft::with_defaults()
            .compute(&samples, true)
            .expect("FFT failed");

        assert_relative_eq!(spectrum[0].re, 10.0, epsilon = TOL);
    }

    /// A forward transform followed by the inverse reproduces the input.
    #[test]
    fn test_large_fft_roundtrip() {
        let engine = LargeFft::with_defaults();
        let input: Vec<f64> = (1..=8).map(f64::from).collect();

        let forward = engine.compute(&input, true).expect("Forward FFT failed");
        let inverse = engine
            .compute_complex(&forward, false)
            .expect("Inverse FFT failed");

        for (recovered, &orig) in inverse[..input.len()].iter().zip(&input) {
            assert_relative_eq!(recovered.re, orig, epsilon = TOL);
            assert_relative_eq!(recovered.im, 0.0, epsilon = TOL);
        }
    }

    /// Strategy selection follows the configured L2 / L3 / memory limits.
    #[test]
    fn test_method_selection() {
        let engine = LargeFft::new(LargeFftConfig {
            l2_cache_size: 256 * 1024,
            l3_cache_size: 8 * 1024 * 1024,
            target_memory_bytes: 256 * 1024 * 1024,
            ..Default::default()
        });

        let expectations = [
            (1024, LargeFftMethod::Direct),
            (100_000, LargeFftMethod::CacheBlocked),
            (1_000_000, LargeFftMethod::Streaming),
        ];

        for (size, expected) in expectations {
            assert_eq!(engine.select_method(size), expected);
        }
    }

    /// The DC bin of a 2-D forward transform equals the sum of all samples.
    #[test]
    fn test_large_fft_2d() {
        let grid: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];

        let spectrum = LargeFftNd::default()
            .compute_2d(&grid, 2, 2, true)
            .expect("2D FFT failed");

        assert_relative_eq!(spectrum[0].re, 10.0, epsilon = TOL);
    }

    /// A larger input is estimated to need more memory than a small one.
    #[test]
    fn test_memory_estimation() {
        let engine = LargeFft::with_defaults();

        let small_mem = engine.estimate_memory(1024);
        let large_mem = engine.estimate_memory(1_000_000);

        assert!(large_mem > small_mem);
    }

    /// The chosen block is a power of two bounded by both limits.
    #[test]
    fn test_find_optimal_block_size() {
        let (total, cache) = (65536, 16384);
        let block = find_optimal_block_size(total, cache);

        assert!(block.is_power_of_two());
        assert!(block <= cache);
        assert!(block <= total);
    }
}