Skip to main content

scirs2_fft/
real_planner.rs

1//! Real FFT planner with trait object support
2//!
3//! This module provides trait object interfaces for real-to-complex and complex-to-real
4//! FFT operations, matching the API patterns used by realfft crate for easier migration.
5//! Uses OxiFFT as the backend (COOLJAPAN Pure Rust policy).
6//!
7//! # Features
8//!
9//! - `RealToComplex` trait for forward real-to-complex FFT operations
10//! - `ComplexToReal` trait for inverse complex-to-real FFT operations
//! - `RealFftPlanner` for creating reusable FFT plans
//! - Support for both f32 and f64 precision
//! - Thread-safe plans that can be stored and shared as `Arc<dyn Trait>`
14//!
15//! # Examples
16//!
17//! ```
18//! use scirs2_fft::real_planner::{RealFftPlanner, RealToComplex, ComplexToReal};
19//! use std::sync::Arc;
20//!
21//! // Create a planner
22//! let mut planner = RealFftPlanner::<f64>::new();
23//!
24//! // Plan forward FFT
25//! let forward_fft = planner.plan_fft_forward(1024);
26//!
27//! // Plan inverse FFT
28//! let inverse_fft = planner.plan_fft_inverse(1024);
29//!
30//! // Use in struct (common VoiRS pattern)
31//! struct AudioProcessor {
32//!     forward: Arc<dyn RealToComplex<f64>>,
33//!     backward: Arc<dyn ComplexToReal<f64>>,
34//! }
35//! ```
36
37use crate::error::{FFTError, FFTResult};
38#[cfg(feature = "oxifft")]
39use crate::oxifft_plan_cache;
40#[cfg(feature = "oxifft")]
41use oxifft::{Complex as OxiComplex, Direction};
42use scirs2_core::numeric::Complex;
43use scirs2_core::numeric::Float;
44
45/// Trait for real-to-complex FFT operations
46///
47/// This trait defines the interface for forward FFT transforms that convert
48/// real-valued input data to complex-valued frequency domain output.
49pub trait RealToComplex<T: Float>: Send + Sync {
50    /// Process a real-valued input and produce complex-valued output
51    ///
52    /// # Arguments
53    ///
54    /// * `input` - Real-valued input samples
55    /// * `output` - Complex-valued frequency domain output (length = input.len()/2 + 1)
56    fn process(&self, input: &[T], output: &mut [Complex<T>]) -> FFTResult<()>;
57
58    /// Get the length of the input this FFT is configured for
59    fn len(&self) -> usize;
60
61    /// Check if this FFT is empty (length 0)
62    fn is_empty(&self) -> bool {
63        self.len() == 0
64    }
65
66    /// Get the length of the output this FFT produces
67    fn output_len(&self) -> usize {
68        self.len() / 2 + 1
69    }
70}
71
72/// Trait for complex-to-real FFT operations
73///
74/// This trait defines the interface for inverse FFT transforms that convert
75/// complex-valued frequency domain data back to real-valued time domain output.
76pub trait ComplexToReal<T: Float>: Send + Sync {
77    /// Process a complex-valued input and produce real-valued output
78    ///
79    /// # Arguments
80    ///
81    /// * `input` - Complex-valued frequency domain samples (length = output.len()/2 + 1)
82    /// * `output` - Real-valued time domain output
83    fn process(&self, input: &[Complex<T>], output: &mut [T]) -> FFTResult<()>;
84
85    /// Get the length of the output this IFFT is configured for
86    fn len(&self) -> usize;
87
88    /// Check if this IFFT is empty (length 0)
89    fn is_empty(&self) -> bool {
90        self.len() == 0
91    }
92
93    /// Get the length of the input this IFFT expects
94    fn input_len(&self) -> usize {
95        self.len() / 2 + 1
96    }
97}
98
/// Forward real-to-complex plan for `f64` data.
///
/// The plan only records its transform length; the transform itself is run
/// when `process` is called.
struct RealFftPlanF64 {
    /// Number of real input samples the plan accepts.
    length: usize,
}

impl RealFftPlanF64 {
    /// Build a plan for inputs of exactly `length` samples.
    fn new(length: usize) -> Self {
        RealFftPlanF64 { length }
    }
}
109
110impl RealToComplex<f64> for RealFftPlanF64 {
111    fn process(&self, input: &[f64], output: &mut [Complex<f64>]) -> FFTResult<()> {
112        // Validate input/output lengths
113        if input.len() != self.length {
114            return Err(FFTError::ValueError(format!(
115                "Input length {} doesn't match plan length {}",
116                input.len(),
117                self.length
118            )));
119        }
120        if output.len() != self.output_len() {
121            return Err(FFTError::ValueError(format!(
122                "Output length {} doesn't match expected length {}",
123                output.len(),
124                self.output_len()
125            )));
126        }
127
128        #[cfg(feature = "oxifft")]
129        {
130            // Convert real input to complex for full FFT
131            let input_oxi: Vec<OxiComplex<f64>> =
132                input.iter().map(|&x| OxiComplex::new(x, 0.0)).collect();
133            let mut output_oxi: Vec<OxiComplex<f64>> = vec![OxiComplex::new(0.0, 0.0); self.length];
134
135            oxifft_plan_cache::execute_c2c(&input_oxi, &mut output_oxi, Direction::Forward)?;
136
137            // Copy first n/2 + 1 elements to output (real FFT Hermitian symmetry property)
138            let out_len = self.output_len();
139            for (i, dst) in output.iter_mut().enumerate().take(out_len) {
140                *dst = Complex::new(output_oxi[i].re, output_oxi[i].im);
141            }
142        }
143
144        #[cfg(not(feature = "oxifft"))]
145        {
146            // Fallback: zero-fill output when no backend available
147            for dst in output.iter_mut() {
148                *dst = Complex::new(0.0, 0.0);
149            }
150        }
151
152        Ok(())
153    }
154
155    fn len(&self) -> usize {
156        self.length
157    }
158}
159
/// Inverse complex-to-real plan for `f64` data.
///
/// The plan only records the length of the real output it produces; the
/// transform itself is run when `process` is called.
struct InverseRealFftPlanF64 {
    /// Number of real output samples the plan produces.
    length: usize,
}

impl InverseRealFftPlanF64 {
    /// Build a plan producing exactly `length` real samples.
    fn new(length: usize) -> Self {
        InverseRealFftPlanF64 { length }
    }
}
170
171impl ComplexToReal<f64> for InverseRealFftPlanF64 {
172    fn process(&self, input: &[Complex<f64>], output: &mut [f64]) -> FFTResult<()> {
173        // Validate input/output lengths
174        if input.len() != self.input_len() {
175            return Err(FFTError::ValueError(format!(
176                "Input length {} doesn't match expected length {}",
177                input.len(),
178                self.input_len()
179            )));
180        }
181        if output.len() != self.length {
182            return Err(FFTError::ValueError(format!(
183                "Output length {} doesn't match plan length {}",
184                output.len(),
185                self.length
186            )));
187        }
188
189        #[cfg(feature = "oxifft")]
190        {
191            // Reconstruct full spectrum using Hermitian symmetry
192            let mut buffer_oxi: Vec<OxiComplex<f64>> = Vec::with_capacity(self.length);
193
194            // Add the provided half-spectrum
195            for &c in input.iter() {
196                buffer_oxi.push(OxiComplex::new(c.re, c.im));
197            }
198
199            // Add conjugate symmetric part
200            let start_idx = if self.length % 2 == 0 {
201                input.len() - 1
202            } else {
203                input.len()
204            };
205
206            for i in (1..start_idx).rev() {
207                buffer_oxi.push(OxiComplex::new(input[i].re, -input[i].im));
208            }
209
210            // Pad to full length if needed
211            while buffer_oxi.len() < self.length {
212                buffer_oxi.push(OxiComplex::new(0.0, 0.0));
213            }
214
215            let mut out_oxi: Vec<OxiComplex<f64>> = vec![OxiComplex::new(0.0, 0.0); self.length];
216
217            oxifft_plan_cache::execute_c2c(&buffer_oxi, &mut out_oxi, Direction::Backward)?;
218
219            // Extract real parts and normalize
220            let scale = 1.0 / self.length as f64;
221            for (i, dst) in output.iter_mut().enumerate() {
222                *dst = out_oxi[i].re * scale;
223            }
224        }
225
226        #[cfg(not(feature = "oxifft"))]
227        {
228            for dst in output.iter_mut() {
229                *dst = 0.0;
230            }
231        }
232
233        Ok(())
234    }
235
236    fn len(&self) -> usize {
237        self.length
238    }
239}
240
/// Forward real-to-complex plan for `f32` data.
///
/// Samples are widened to `f64` for the computation, and the resulting
/// spectrum is narrowed back to `f32`.
struct RealFftPlanF32 {
    /// Number of real input samples the plan accepts.
    length: usize,
}

impl RealFftPlanF32 {
    /// Build a plan for inputs of exactly `length` samples.
    fn new(length: usize) -> Self {
        RealFftPlanF32 { length }
    }
}
253
254impl RealToComplex<f32> for RealFftPlanF32 {
255    fn process(&self, input: &[f32], output: &mut [Complex<f32>]) -> FFTResult<()> {
256        // Validate input/output lengths
257        if input.len() != self.length {
258            return Err(FFTError::ValueError(format!(
259                "Input length {} doesn't match plan length {}",
260                input.len(),
261                self.length
262            )));
263        }
264        if output.len() != self.output_len() {
265            return Err(FFTError::ValueError(format!(
266                "Output length {} doesn't match expected length {}",
267                output.len(),
268                self.output_len()
269            )));
270        }
271
272        #[cfg(feature = "oxifft")]
273        {
274            // Convert f32 real input to f64 complex for OxiFFT
275            let input_oxi: Vec<OxiComplex<f64>> = input
276                .iter()
277                .map(|&x| OxiComplex::new(x as f64, 0.0))
278                .collect();
279            let mut output_oxi: Vec<OxiComplex<f64>> = vec![OxiComplex::new(0.0, 0.0); self.length];
280
281            oxifft_plan_cache::execute_c2c(&input_oxi, &mut output_oxi, Direction::Forward)?;
282
283            // Copy first n/2 + 1 elements with f64->f32 conversion
284            let out_len = self.output_len();
285            for (i, dst) in output.iter_mut().enumerate().take(out_len) {
286                *dst = Complex::new(output_oxi[i].re as f32, output_oxi[i].im as f32);
287            }
288        }
289
290        #[cfg(not(feature = "oxifft"))]
291        {
292            for dst in output.iter_mut() {
293                *dst = Complex::new(0.0f32, 0.0f32);
294            }
295        }
296
297        Ok(())
298    }
299
300    fn len(&self) -> usize {
301        self.length
302    }
303}
304
/// Inverse complex-to-real plan for `f32` data.
///
/// Spectrum bins are widened to `f64` for the computation, and the resulting
/// samples are narrowed back to `f32`.
struct InverseRealFftPlanF32 {
    /// Number of real output samples the plan produces.
    length: usize,
}

impl InverseRealFftPlanF32 {
    /// Build a plan producing exactly `length` real samples.
    fn new(length: usize) -> Self {
        InverseRealFftPlanF32 { length }
    }
}
315
316impl ComplexToReal<f32> for InverseRealFftPlanF32 {
317    fn process(&self, input: &[Complex<f32>], output: &mut [f32]) -> FFTResult<()> {
318        // Validate input/output lengths
319        if input.len() != self.input_len() {
320            return Err(FFTError::ValueError(format!(
321                "Input length {} doesn't match expected length {}",
322                input.len(),
323                self.input_len()
324            )));
325        }
326        if output.len() != self.length {
327            return Err(FFTError::ValueError(format!(
328                "Output length {} doesn't match plan length {}",
329                output.len(),
330                self.length
331            )));
332        }
333
334        #[cfg(feature = "oxifft")]
335        {
336            // Reconstruct full spectrum using Hermitian symmetry (with f32->f64 conversion)
337            let mut buffer_oxi: Vec<OxiComplex<f64>> = Vec::with_capacity(self.length);
338
339            for &c in input.iter() {
340                buffer_oxi.push(OxiComplex::new(c.re as f64, c.im as f64));
341            }
342
343            let start_idx = if self.length % 2 == 0 {
344                input.len() - 1
345            } else {
346                input.len()
347            };
348
349            for i in (1..start_idx).rev() {
350                buffer_oxi.push(OxiComplex::new(input[i].re as f64, -(input[i].im as f64)));
351            }
352
353            while buffer_oxi.len() < self.length {
354                buffer_oxi.push(OxiComplex::new(0.0, 0.0));
355            }
356
357            let mut out_oxi: Vec<OxiComplex<f64>> = vec![OxiComplex::new(0.0, 0.0); self.length];
358
359            oxifft_plan_cache::execute_c2c(&buffer_oxi, &mut out_oxi, Direction::Backward)?;
360
361            // Extract real parts, normalize, and convert back to f32
362            let scale = 1.0 / self.length as f64;
363            for (i, dst) in output.iter_mut().enumerate() {
364                *dst = (out_oxi[i].re * scale) as f32;
365            }
366        }
367
368        #[cfg(not(feature = "oxifft"))]
369        {
370            for dst in output.iter_mut() {
371                *dst = 0.0f32;
372            }
373        }
374
375        Ok(())
376    }
377
378    fn len(&self) -> usize {
379        self.length
380    }
381}
382
383/// Real FFT planner for creating and managing FFT plans
384///
385/// This planner creates reusable FFT plans optimized for real-valued input/output.
386/// Plans are thread-safe and can be shared across threads using Arc.
387/// Uses OxiFFT as the backend (COOLJAPAN Pure Rust policy).
388///
389/// # Type Parameters
390///
391/// * `T` - Float type (f32 or f64)
392///
393/// # Examples
394///
395/// ```
396/// use scirs2_fft::real_planner::RealFftPlanner;
397///
398/// let mut planner = RealFftPlanner::<f64>::new();
399/// let forward = planner.plan_fft_forward(1024);
400/// let inverse = planner.plan_fft_inverse(1024);
401/// ```
402pub struct RealFftPlanner<T: Float> {
403    _phantom: std::marker::PhantomData<T>,
404}
405
406impl RealFftPlanner<f64> {
407    /// Create a new planner for f64 precision
408    pub fn new() -> Self {
409        Self {
410            _phantom: std::marker::PhantomData,
411        }
412    }
413
414    /// Create a forward FFT plan for real-to-complex transformation
415    ///
416    /// # Arguments
417    ///
418    /// * `length` - Length of the input real-valued array
419    ///
420    /// # Returns
421    ///
422    /// Arc-wrapped trait object implementing RealToComplex
423    pub fn plan_fft_forward(&mut self, length: usize) -> std::sync::Arc<dyn RealToComplex<f64>> {
424        std::sync::Arc::new(RealFftPlanF64::new(length))
425    }
426
427    /// Create an inverse FFT plan for complex-to-real transformation
428    ///
429    /// # Arguments
430    ///
431    /// * `length` - Length of the output real-valued array
432    ///
433    /// # Returns
434    ///
435    /// Arc-wrapped trait object implementing ComplexToReal
436    pub fn plan_fft_inverse(&mut self, length: usize) -> std::sync::Arc<dyn ComplexToReal<f64>> {
437        std::sync::Arc::new(InverseRealFftPlanF64::new(length))
438    }
439}
440
441impl Default for RealFftPlanner<f64> {
442    fn default() -> Self {
443        Self::new()
444    }
445}
446
447impl RealFftPlanner<f32> {
448    /// Create a new planner for f32 precision
449    pub fn new() -> Self {
450        Self {
451            _phantom: std::marker::PhantomData,
452        }
453    }
454
455    /// Create a forward FFT plan for real-to-complex transformation
456    ///
457    /// # Arguments
458    ///
459    /// * `length` - Length of the input real-valued array
460    ///
461    /// # Returns
462    ///
463    /// Arc-wrapped trait object implementing RealToComplex
464    pub fn plan_fft_forward(&mut self, length: usize) -> std::sync::Arc<dyn RealToComplex<f32>> {
465        std::sync::Arc::new(RealFftPlanF32::new(length))
466    }
467
468    /// Create an inverse FFT plan for complex-to-real transformation
469    ///
470    /// # Arguments
471    ///
472    /// * `length` - Length of the output real-valued array
473    ///
474    /// # Returns
475    ///
476    /// Arc-wrapped trait object implementing ComplexToReal
477    pub fn plan_fft_inverse(&mut self, length: usize) -> std::sync::Arc<dyn ComplexToReal<f32>> {
478        std::sync::Arc::new(InverseRealFftPlanF32::new(length))
479    }
480}
481
482impl Default for RealFftPlanner<f32> {
483    fn default() -> Self {
484        Self::new()
485    }
486}
487
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::numeric::Complex64;
    use std::f64::consts::PI;

    #[test]
    fn test_real_fft_planner_f64() {
        let mut planner = RealFftPlanner::<f64>::new();
        let forward = planner.plan_fft_forward(8);
        let inverse = planner.plan_fft_inverse(8);

        // Test input
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut spectrum = vec![Complex64::new(0.0, 0.0); 5]; // 8/2 + 1 = 5

        // Forward transform
        forward
            .process(&input, &mut spectrum)
            .expect("Forward FFT failed");

        // Check DC component
        let sum: f64 = input.iter().sum();
        assert!((spectrum[0].re - sum).abs() < 1e-10);
        assert!(spectrum[0].im.abs() < 1e-10);

        // Inverse transform
        let mut recovered = vec![0.0; 8];
        inverse
            .process(&spectrum, &mut recovered)
            .expect("Inverse FFT failed");

        // Check round-trip accuracy
        for (i, (&orig, &recov)) in input.iter().zip(recovered.iter()).enumerate() {
            assert!(
                (orig - recov).abs() < 1e-10,
                "Mismatch at index {}: {} vs {}",
                i,
                orig,
                recov
            );
        }
    }

    #[test]
    fn test_real_fft_planner_f32() {
        let mut planner = RealFftPlanner::<f32>::new();
        let forward = planner.plan_fft_forward(8);
        let inverse = planner.plan_fft_inverse(8);

        // Test input
        let input = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut spectrum = vec![Complex::new(0.0f32, 0.0); 5]; // 8/2 + 1 = 5

        // Forward transform
        forward
            .process(&input, &mut spectrum)
            .expect("Forward FFT failed");

        // Inverse transform
        let mut recovered = vec![0.0f32; 8];
        inverse
            .process(&spectrum, &mut recovered)
            .expect("Inverse FFT failed");

        // Check round-trip accuracy (lower precision for f32)
        for (i, (&orig, &recov)) in input.iter().zip(recovered.iter()).enumerate() {
            assert!(
                (orig - recov).abs() < 1e-5,
                "Mismatch at index {}: {} vs {}",
                i,
                orig,
                recov
            );
        }
    }

    #[test]
    fn test_sine_wave_fft() {
        let mut planner = RealFftPlanner::<f64>::new();
        let length = 128;
        let forward = planner.plan_fft_forward(length);

        // Generate sine wave at frequency bin 5
        let freq_index = 5;
        let input: Vec<f64> = (0..length)
            .map(|i| (2.0 * PI * freq_index as f64 * i as f64 / length as f64).sin())
            .collect();

        let mut spectrum = vec![Complex64::new(0.0, 0.0); length / 2 + 1];
        forward.process(&input, &mut spectrum).expect("FFT failed");

        // Check that energy is concentrated at the expected frequency
        let magnitudes: Vec<f64> = spectrum.iter().map(|c| c.norm()).collect();

        // Find peak
        let (peak_idx, &peak_mag) = magnitudes
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("Operation failed"))
            .expect("Operation failed");

        assert_eq!(
            peak_idx, freq_index,
            "Peak should be at frequency index {}",
            freq_index
        );
        assert!(peak_mag > length as f64 / 4.0, "Peak magnitude too small");
    }

    #[test]
    fn test_plan_properties() {
        let mut planner = RealFftPlanner::<f64>::new();
        let forward = planner.plan_fft_forward(1024);

        assert_eq!(forward.len(), 1024);
        assert_eq!(forward.output_len(), 513); // 1024/2 + 1
        assert!(!forward.is_empty());
    }

    #[test]
    fn test_voirs_usage_pattern() {
        // This test demonstrates the VoiRS usage pattern with Arc<dyn Trait>
        struct AudioProcessor {
            forward: std::sync::Arc<dyn RealToComplex<f64>>,
            backward: std::sync::Arc<dyn ComplexToReal<f64>>,
        }

        impl AudioProcessor {
            fn new(size: usize) -> Self {
                let mut planner = RealFftPlanner::<f64>::new();
                Self {
                    forward: planner.plan_fft_forward(size),
                    backward: planner.plan_fft_inverse(size),
                }
            }

            fn process(&self, input: &[f64]) -> Vec<f64> {
                let mut spectrum = vec![Complex64::new(0.0, 0.0); self.forward.output_len()];
                self.forward
                    .process(input, &mut spectrum)
                    .expect("Forward FFT failed");

                let mut output = vec![0.0; self.backward.len()];
                self.backward
                    .process(&spectrum, &mut output)
                    .expect("Inverse FFT failed");

                output
            }
        }

        let processor = AudioProcessor::new(16);
        let input = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
        ];
        let output = processor.process(&input);

        // Verify round-trip
        for (i, (&orig, &recov)) in input.iter().zip(output.iter()).enumerate() {
            assert!((orig - recov).abs() < 1e-10, "Mismatch at {}", i);
        }
    }

    /// Round-trips odd lengths. These exercise the branch of the inverse
    /// transform that mirrors every non-DC bin (odd lengths have no Nyquist bin).
    #[test]
    fn test_odd_length_round_trip() {
        let mut planner = RealFftPlanner::<f64>::new();
        for &length in &[3usize, 5, 9, 15] {
            let forward = planner.plan_fft_forward(length);
            let inverse = planner.plan_fft_inverse(length);

            let input: Vec<f64> = (0..length)
                .map(|i| (i as f64 * 0.7).sin() + i as f64)
                .collect();
            let mut spectrum = vec![Complex64::new(0.0, 0.0); forward.output_len()];
            forward
                .process(&input, &mut spectrum)
                .expect("Forward FFT failed");

            let mut recovered = vec![0.0; length];
            inverse
                .process(&spectrum, &mut recovered)
                .expect("Inverse FFT failed");

            for (i, (&orig, &recov)) in input.iter().zip(recovered.iter()).enumerate() {
                assert!(
                    (orig - recov).abs() < 1e-10,
                    "Length {}: mismatch at index {}: {} vs {}",
                    length,
                    i,
                    orig,
                    recov
                );
            }
        }
    }

    /// Checks every forward output bin against a direct DFT evaluation.
    #[test]
    fn test_forward_matches_direct_dft() {
        let length = 6;
        let mut planner = RealFftPlanner::<f64>::new();
        let forward = planner.plan_fft_forward(length);

        let input = vec![0.5, -1.25, 2.0, 3.5, -0.75, 1.0];
        let mut spectrum = vec![Complex64::new(0.0, 0.0); forward.output_len()];
        forward
            .process(&input, &mut spectrum)
            .expect("Forward FFT failed");

        for (k, bin) in spectrum.iter().enumerate() {
            let expected = input
                .iter()
                .enumerate()
                .fold(Complex64::new(0.0, 0.0), |acc, (j, &x)| {
                    let angle = -2.0 * PI * (j * k) as f64 / length as f64;
                    acc + Complex64::new(x * angle.cos(), x * angle.sin())
                });
            assert!(
                (*bin - expected).norm() < 1e-10,
                "Bin {}: {} vs {}",
                k,
                bin,
                expected
            );
        }
    }

    /// Mismatched buffer lengths must be reported as errors, not panics.
    #[test]
    fn test_length_mismatch_is_rejected() {
        let mut planner = RealFftPlanner::<f64>::new();
        let forward = planner.plan_fft_forward(8);
        let inverse = planner.plan_fft_inverse(8);

        let mut spectrum = vec![Complex64::new(0.0, 0.0); 5];
        let mut short_spectrum = vec![Complex64::new(0.0, 0.0); 4];
        let mut output = vec![0.0; 8];

        assert!(forward.process(&[0.0; 7], &mut spectrum).is_err());
        assert!(forward.process(&[0.0; 8], &mut short_spectrum).is_err());
        assert!(inverse.process(&short_spectrum, &mut output).is_err());
        assert!(inverse.process(&spectrum, &mut output[..7]).is_err());
    }
}