Skip to main content

circulant_rs/fft/
rustfft_backend.rs

1// @module: crate::fft::rustfft_backend
2// @status: stable
3// @owner: code_expert
4// @feature: none
5// @depends: [crate::error, crate::fft::backend, crate::traits, rustfft, num_complex]
6// @tests: [unit]
7
8//! RustFFT backend implementation.
9
10use crate::error::{CirculantError, Result};
11use crate::fft::backend::FftBackend;
12use crate::traits::Scalar;
13use num_complex::Complex;
14use rustfft::{Fft, FftPlanner};
15use std::sync::Arc;
16
17/// FFT backend using the rustfft crate.
18///
19/// This is the default FFT backend for circulant-rs.
20pub struct RustFftBackend<T: Scalar + rustfft::FftNum> {
21    forward: Arc<dyn Fft<T>>,
22    inverse: Arc<dyn Fft<T>>,
23    size: usize,
24    scratch_len: usize,
25}
26
27impl<T: Scalar + rustfft::FftNum> RustFftBackend<T> {
28    /// Create a new RustFFT backend for the given size.
29    ///
30    /// # Errors
31    ///
32    /// Returns `InvalidFftSize` if size is 0.
33    pub fn new(size: usize) -> Result<Self> {
34        if size == 0 {
35            return Err(CirculantError::InvalidFftSize(0));
36        }
37
38        let mut planner = FftPlanner::new();
39        let forward = planner.plan_fft_forward(size);
40        let inverse = planner.plan_fft_inverse(size);
41
42        let scratch_len = forward
43            .get_inplace_scratch_len()
44            .max(inverse.get_inplace_scratch_len());
45
46        Ok(Self {
47            forward,
48            inverse,
49            size,
50            scratch_len,
51        })
52    }
53}
54
55impl<T: Scalar + rustfft::FftNum> FftBackend<T> for RustFftBackend<T> {
56    fn fft_forward(&self, buffer: &mut [Complex<T>]) {
57        let mut scratch = vec![Complex::new(T::zero(), T::zero()); self.scratch_len];
58        self.forward.process_with_scratch(buffer, &mut scratch);
59    }
60
61    fn fft_inverse(&self, buffer: &mut [Complex<T>]) {
62        let mut scratch = vec![Complex::new(T::zero(), T::zero()); self.scratch_len];
63        self.inverse.process_with_scratch(buffer, &mut scratch);
64
65        // Normalize by 1/N
66        let scale = T::one() / T::from(self.size).unwrap_or_else(|| T::one());
67        for val in buffer.iter_mut() {
68            *val = Complex::new(val.re * scale, val.im * scale);
69        }
70    }
71
72    fn size(&self) -> usize {
73        self.size
74    }
75
76    fn make_scratch(&self) -> Vec<Complex<T>> {
77        vec![Complex::new(T::zero(), T::zero()); self.scratch_len]
78    }
79}
80
81// Implement Send + Sync since Arc<dyn Fft<T>> is Send + Sync
82unsafe impl<T: Scalar + rustfft::FftNum> Send for RustFftBackend<T> {}
83unsafe impl<T: Scalar + rustfft::FftNum> Sync for RustFftBackend<T> {}
84
85#[cfg(test)]
86mod tests {
87    use super::*;
88    use approx::assert_relative_eq;
89    use std::f64::consts::PI;
90
91    #[test]
92    fn test_fft_forward_inverse_roundtrip() {
93        let backend = RustFftBackend::<f64>::new(8).unwrap();
94
95        // Original signal: [1, 2, 3, 4, 5, 6, 7, 8]
96        let original: Vec<Complex<f64>> = (1..=8).map(|x| Complex::new(x as f64, 0.0)).collect();
97
98        let mut buffer = original.clone();
99
100        // Forward FFT
101        backend.fft_forward(&mut buffer);
102
103        // Inverse FFT
104        backend.fft_inverse(&mut buffer);
105
106        // Should match original
107        for (orig, result) in original.iter().zip(buffer.iter()) {
108            assert_relative_eq!(orig.re, result.re, epsilon = 1e-10);
109            assert_relative_eq!(orig.im, result.im, epsilon = 1e-10);
110        }
111    }
112
113    #[test]
114    fn test_fft_known_values() {
115        // DFT of [1, 0, 0, 0] should be [1, 1, 1, 1]
116        let backend = RustFftBackend::<f64>::new(4).unwrap();
117
118        let mut buffer = vec![
119            Complex::new(1.0, 0.0),
120            Complex::new(0.0, 0.0),
121            Complex::new(0.0, 0.0),
122            Complex::new(0.0, 0.0),
123        ];
124
125        backend.fft_forward(&mut buffer);
126
127        for val in &buffer {
128            assert_relative_eq!(val.re, 1.0, epsilon = 1e-10);
129            assert_relative_eq!(val.im, 0.0, epsilon = 1e-10);
130        }
131    }
132
133    #[test]
134    fn test_fft_known_values_sinusoid() {
135        // DFT of a pure sinusoid should have peaks at the frequency bins
136        let n = 8;
137        let backend = RustFftBackend::<f64>::new(n).unwrap();
138
139        // cos(2*pi*k/N) for k=0..N-1, with frequency 1
140        let mut buffer: Vec<Complex<f64>> = (0..n)
141            .map(|k| {
142                let theta = 2.0 * PI * (k as f64) / (n as f64);
143                Complex::new(theta.cos(), 0.0)
144            })
145            .collect();
146
147        backend.fft_forward(&mut buffer);
148
149        // For cos, expect peaks at bins 1 and N-1 (which is 7)
150        // Bin 0 should be ~0 (DC component)
151        // Bin 1 and 7 should be ~N/2 = 4
152        assert_relative_eq!(buffer[0].re, 0.0, epsilon = 1e-10);
153        assert_relative_eq!(buffer[1].re, 4.0, epsilon = 1e-10);
154        assert_relative_eq!(buffer[7].re, 4.0, epsilon = 1e-10);
155
156        // Other bins should be ~0
157        for i in [2, 3, 4, 5, 6] {
158            assert_relative_eq!(buffer[i].re.abs(), 0.0, epsilon = 1e-10);
159            assert_relative_eq!(buffer[i].im.abs(), 0.0, epsilon = 1e-10);
160        }
161    }
162
163    #[test]
164    fn test_fft_linearity() {
165        let backend = RustFftBackend::<f64>::new(4).unwrap();
166
167        let x: Vec<Complex<f64>> = vec![
168            Complex::new(1.0, 0.0),
169            Complex::new(2.0, 0.0),
170            Complex::new(3.0, 0.0),
171            Complex::new(4.0, 0.0),
172        ];
173        let y: Vec<Complex<f64>> = vec![
174            Complex::new(5.0, 0.0),
175            Complex::new(6.0, 0.0),
176            Complex::new(7.0, 0.0),
177            Complex::new(8.0, 0.0),
178        ];
179
180        // Compute FFT(x) and FFT(y)
181        let mut fft_x = x.clone();
182        let mut fft_y = y.clone();
183        backend.fft_forward(&mut fft_x);
184        backend.fft_forward(&mut fft_y);
185
186        // Compute FFT(x + y)
187        let mut sum: Vec<Complex<f64>> = x.iter().zip(y.iter()).map(|(a, b)| a + b).collect();
188        backend.fft_forward(&mut sum);
189
190        // FFT(x + y) should equal FFT(x) + FFT(y)
191        for i in 0..4 {
192            let expected = fft_x[i] + fft_y[i];
193            assert_relative_eq!(sum[i].re, expected.re, epsilon = 1e-10);
194            assert_relative_eq!(sum[i].im, expected.im, epsilon = 1e-10);
195        }
196
197        // Test scaling: FFT(alpha * x) = alpha * FFT(x)
198        let alpha = Complex::new(2.5, 0.0);
199        let mut scaled_x: Vec<Complex<f64>> = x.iter().map(|v| alpha * v).collect();
200        backend.fft_forward(&mut scaled_x);
201
202        let mut original_fft = x.clone();
203        backend.fft_forward(&mut original_fft);
204
205        for i in 0..4 {
206            let expected = alpha * original_fft[i];
207            assert_relative_eq!(scaled_x[i].re, expected.re, epsilon = 1e-10);
208            assert_relative_eq!(scaled_x[i].im, expected.im, epsilon = 1e-10);
209        }
210    }
211
212    #[test]
213    fn test_fft_size() {
214        let backend = RustFftBackend::<f64>::new(16).unwrap();
215        assert_eq!(backend.size(), 16);
216    }
217
218    #[test]
219    fn test_fft_zero_size_returns_error() {
220        let result = RustFftBackend::<f64>::new(0);
221        assert!(matches!(
222            result,
223            Err(crate::error::CirculantError::InvalidFftSize(0))
224        ));
225    }
226
227    #[test]
228    fn test_fft_f32() {
229        let backend = RustFftBackend::<f32>::new(4).unwrap();
230
231        let original: Vec<Complex<f32>> = vec![
232            Complex::new(1.0, 0.0),
233            Complex::new(2.0, 0.0),
234            Complex::new(3.0, 0.0),
235            Complex::new(4.0, 0.0),
236        ];
237
238        let mut buffer = original.clone();
239        backend.fft_forward(&mut buffer);
240        backend.fft_inverse(&mut buffer);
241
242        for (orig, result) in original.iter().zip(buffer.iter()) {
243            assert_relative_eq!(orig.re, result.re, epsilon = 1e-5);
244            assert_relative_eq!(orig.im, result.im, epsilon = 1e-5);
245        }
246    }
247}